• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
18 
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
23 
24 namespace xla {
25 namespace testing {
26 
27 class HloMatcher : public ::testing::MatcherInterface<const HloInstruction*> {
28  public:
HloMatcher(HloOpcode opcode,std::vector<::testing::Matcher<const HloInstruction * >> operands)29   HloMatcher(HloOpcode opcode,
30              std::vector<::testing::Matcher<const HloInstruction*>> operands)
31       : opcode_(opcode), operands_(operands) {}
32 
33   bool MatchAndExplain(const HloInstruction* instruction,
34                        ::testing::MatchResultListener* listener) const override;
35 
36   void DescribeTo(::std::ostream* os) const override;
37 
38  private:
39   HloOpcode opcode_;
40   std::vector<::testing::Matcher<const HloInstruction*>> operands_;
41 };
42 
43 // Custom matcher for parameters, which accepts a parameter number.
44 class HloParameterMatcher : public HloMatcher {
45  public:
HloParameterMatcher(int64 parameter_number)46   explicit HloParameterMatcher(int64 parameter_number)
47       : HloMatcher(HloOpcode::kParameter, /*operands=*/{}),
48         parameter_number_(parameter_number) {}
49 
50   bool MatchAndExplain(const HloInstruction* instruction,
51                        ::testing::MatchResultListener* listener) const override;
52 
53  private:
54   int64 parameter_number_;
55 };
56 
57 // Custom matcher for comparisons, which accepts a comparison direction.
58 class HloComparisonMatcher : public HloMatcher {
59  public:
HloComparisonMatcher(ComparisonDirection direction,std::vector<::testing::Matcher<const HloInstruction * >> operands)60   explicit HloComparisonMatcher(
61       ComparisonDirection direction,
62       std::vector<::testing::Matcher<const HloInstruction*>> operands)
63       : HloMatcher(HloOpcode::kCompare, operands), direction_(direction) {}
64 
65   bool MatchAndExplain(const HloInstruction* instruction,
66                        ::testing::MatchResultListener* listener) const override;
67 
68  private:
69   ComparisonDirection direction_;
70 };
71 
72 // Custom matcher for get-tuple-element instructions, which accepts a tuple
73 // index to match.
74 class HloGetTupleElementMatcher : public HloMatcher {
75  public:
HloGetTupleElementMatcher(::testing::Matcher<const HloInstruction * > operand,int64 tuple_index)76   HloGetTupleElementMatcher(::testing::Matcher<const HloInstruction*> operand,
77                             int64 tuple_index)
78       : HloMatcher(HloOpcode::kGetTupleElement, /*operands=*/{operand}),
79         tuple_index_(tuple_index) {}
80 
81   bool MatchAndExplain(const HloInstruction* instruction,
82                        ::testing::MatchResultListener* listener) const override;
83 
84  private:
85   int64 tuple_index_;
86 };
87 
88 // Custom matcher for custom-call instructions, which accepts a matcher for its
89 // call target.
90 class HloCustomCallMatcher : public HloMatcher {
91  public:
HloCustomCallMatcher(::testing::Matcher<string> call_target_matcher,std::vector<::testing::Matcher<const HloInstruction * >> operands)92   HloCustomCallMatcher(
93       ::testing::Matcher<string> call_target_matcher,
94       std::vector<::testing::Matcher<const HloInstruction*>> operands)
95       : HloMatcher(HloOpcode::kCustomCall, operands),
96         call_target_matcher_(call_target_matcher) {}
97 
98   bool MatchAndExplain(const HloInstruction* instruction,
99                        ::testing::MatchResultListener* listener) const override;
100   void DescribeTo(std::ostream* os) const override;
101 
102  private:
103   ::testing::Matcher<string> call_target_matcher_;
104 };
105 
106 class HloShapeMatcher
107     : public ::testing::MatcherInterface<const HloInstruction*> {
108  public:
HloShapeMatcher(const Shape & shape)109   explicit HloShapeMatcher(const Shape& shape) : shape_(shape) {}
110 
111   bool MatchAndExplain(const HloInstruction* instruction,
112                        ::testing::MatchResultListener* listener) const override;
113   void DescribeTo(std::ostream* os) const override;
114 
115  private:
116   Shape shape_;
117 };
118 
119 class HloShapeAndLayoutMatcher
120     : public ::testing::MatcherInterface<const HloInstruction*> {
121  public:
HloShapeAndLayoutMatcher(const Shape & shape)122   explicit HloShapeAndLayoutMatcher(const Shape& shape) : shape_(shape) {}
123 
124   bool MatchAndExplain(const HloInstruction* instruction,
125                        ::testing::MatchResultListener* listener) const override;
126   void DescribeTo(std::ostream* os) const override;
127 
128  private:
129   Shape shape_;
130 };
131 
132 // Verify the sharding of an instruction against the provided HloSharding. If a
133 // nullopt is provided for the expected sharding then it checks that no sharding
134 // is present for an instruction.
135 class HloShardingMatcher
136     : public ::testing::MatcherInterface<const HloInstruction*> {
137  public:
HloShardingMatcher(const absl::optional<HloSharding> & sharding)138   explicit HloShardingMatcher(const absl::optional<HloSharding>& sharding)
139       : sharding_(sharding) {}
140 
141   bool MatchAndExplain(const HloInstruction* instruction,
142                        ::testing::MatchResultListener* listener) const override;
143   void DescribeTo(std::ostream* os) const override;
144 
145  private:
146   absl::optional<HloSharding> sharding_;
147 };
148 
149 // Matches a Dot HLO instruction with specific LHS and RHS contracting
150 // dimensions.
151 class HloDotWithContractingDimsMatcher : public HloMatcher {
152  public:
HloDotWithContractingDimsMatcher(::testing::Matcher<const HloInstruction * > lhs,::testing::Matcher<const HloInstruction * > rhs,int64 lhs_contracting_dim,int64 rhs_contracting_dim)153   explicit HloDotWithContractingDimsMatcher(
154       ::testing::Matcher<const HloInstruction*> lhs,
155       ::testing::Matcher<const HloInstruction*> rhs, int64 lhs_contracting_dim,
156       int64 rhs_contracting_dim)
157       : HloMatcher(HloOpcode::kDot, /*operands=*/{lhs, rhs}),
158         lhs_contracting_dim_(lhs_contracting_dim),
159         rhs_contracting_dim_(rhs_contracting_dim) {}
160 
161   bool MatchAndExplain(const HloInstruction* instruction,
162                        ::testing::MatchResultListener* listener) const override;
163   void DescribeTo(std::ostream* os) const override;
164 
165  private:
166   int64 lhs_contracting_dim_;
167   int64 rhs_contracting_dim_;
168 };
169 
170 // Custom matcher for asynchronous copy (CopyStart/CopyDone pair) with specified
171 // source and destination memory spaces.
172 class HloAsyncCopyMatcher : public HloMatcher {
173  public:
HloAsyncCopyMatcher(int64 to_space,int64 from_space,::testing::Matcher<const HloInstruction * > operand)174   HloAsyncCopyMatcher(int64 to_space, int64 from_space,
175                       ::testing::Matcher<const HloInstruction*> operand)
176       : HloMatcher(HloOpcode::kCopyDone,
177                    {::testing::MakeMatcher(
178                        new HloMatcher(HloOpcode::kCopyStart, {operand}))}),
179         to_space_(to_space),
180         from_space_(from_space) {}
181 
182   bool MatchAndExplain(const HloInstruction* instruction,
183                        ::testing::MatchResultListener* listener) const override;
184   void DescribeTo(std::ostream* os) const override;
185 
186  private:
187   int64 to_space_;
188   int64 from_space_;
189 };
190 
// HloInstruction* matchers for opcode and operands. Example:
//   namespace op = xla::testing::opcode_matchers;
//   EXPECT_THAT(instruction,
//               op::Add(op::Reshape(), op::Add(op::Reshape(), _)));
195 namespace opcode_matchers {
196 #define HLO_MATCHER(opcode)                                                \
197   template <typename... M>                                                 \
198   ::testing::Matcher<const ::xla::HloInstruction*> opcode(M... operands) { \
199     return ::testing::MakeMatcher(new ::xla::testing::HloMatcher(          \
200         ::xla::HloOpcode::k##opcode, {operands...}));                      \
201   }
202 HLO_MATCHER(Abs);
203 HLO_MATCHER(Add);
204 HLO_MATCHER(AddDependency);
205 HLO_MATCHER(AfterAll);
206 HLO_MATCHER(AllGather);
207 HLO_MATCHER(AllReduce);
208 HLO_MATCHER(AllToAll);
209 HLO_MATCHER(And);
210 HLO_MATCHER(BatchNormGrad);
211 HLO_MATCHER(Bitcast);
212 HLO_MATCHER(Broadcast);
213 HLO_MATCHER(Call);
214 HLO_MATCHER(Ceil);
215 HLO_MATCHER(Clamp);
216 HLO_MATCHER(CollectivePermute);
217 HLO_MATCHER(Compare);
218 HLO_MATCHER(Concatenate);
219 HLO_MATCHER(Conditional);
220 HLO_MATCHER(Constant);
221 HLO_MATCHER(Convert);
222 HLO_MATCHER(Convolution);
223 HLO_MATCHER(Copy);
224 HLO_MATCHER(CopyDone);
225 HLO_MATCHER(CopyStart);
226 HLO_MATCHER(Divide);
227 HLO_MATCHER(Domain);
228 HLO_MATCHER(DynamicSlice);
229 HLO_MATCHER(DynamicUpdateSlice);
230 HLO_MATCHER(Exp);
231 HLO_MATCHER(Fft);
232 HLO_MATCHER(Floor);
233 HLO_MATCHER(Fusion);
234 HLO_MATCHER(Gather);
235 HLO_MATCHER(GetDimensionSize);
236 HLO_MATCHER(Infeed);
237 HLO_MATCHER(Iota);
238 HLO_MATCHER(IsFinite);
239 HLO_MATCHER(Log);
240 HLO_MATCHER(Map);
241 HLO_MATCHER(Maximum);
242 HLO_MATCHER(Minimum);
243 HLO_MATCHER(Multiply);
244 HLO_MATCHER(Negate);
245 HLO_MATCHER(Not);
246 HLO_MATCHER(Or);
247 HLO_MATCHER(Outfeed);
248 HLO_MATCHER(Pad);
249 HLO_MATCHER(PartitionId);
250 HLO_MATCHER(Power);
251 HLO_MATCHER(Recv);
252 HLO_MATCHER(RecvDone);
253 HLO_MATCHER(Reduce);
254 HLO_MATCHER(ReducePrecision);
255 HLO_MATCHER(ReduceWindow);
256 HLO_MATCHER(Remainder);
257 HLO_MATCHER(ReplicaId);
258 HLO_MATCHER(Reshape);
259 HLO_MATCHER(Reverse);
260 HLO_MATCHER(Rng);
261 HLO_MATCHER(Scatter);
262 HLO_MATCHER(Select);
263 HLO_MATCHER(SelectAndScatter);
264 HLO_MATCHER(Send);
265 HLO_MATCHER(SendDone);
266 HLO_MATCHER(SetDimensionSize);
267 HLO_MATCHER(ShiftLeft);
268 HLO_MATCHER(ShiftRightArithmetic);
269 HLO_MATCHER(ShiftRightLogical);
270 HLO_MATCHER(Sign);
271 HLO_MATCHER(Slice);
272 HLO_MATCHER(Sort);
273 HLO_MATCHER(Subtract);
274 HLO_MATCHER(Tanh);
275 HLO_MATCHER(Trace);
276 HLO_MATCHER(Transpose);
277 HLO_MATCHER(Tuple);
278 HLO_MATCHER(TupleSelect);
279 HLO_MATCHER(While);
280 HLO_MATCHER(Xor);
281 
282 // The special cases below let you check additional information about the
283 // HloInstruction, beyond just its opcode and operands.  In all cases you can
284 // still use the generic matcher which doesn't check this info.
285 //
286 // Feel free to add additional custom matchers below.
287 
288 //  - Parameter(N) matches parameter number N.
289 //  - Parameter() matches any parameter.
Parameter(int64 parameter_number)290 inline ::testing::Matcher<const ::xla::HloInstruction*> Parameter(
291     int64 parameter_number) {
292   return ::testing::MakeMatcher(
293       new ::xla::testing::HloParameterMatcher(parameter_number));
294 }
Parameter()295 inline ::testing::Matcher<const ::xla::HloInstruction*> Parameter() {
296   return ::testing::MakeMatcher(
297       new ::xla::testing::HloMatcher(HloOpcode::kParameter, {}));
298 }
299 
300 // Comparison matchers below do not require any additional arguments.
301 template <typename... M>
Eq(M...operands)302 inline ::testing::Matcher<const ::xla::HloInstruction*> Eq(M... operands) {
303   return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
304       ComparisonDirection::kEq, {operands...}));
305 }
306 template <typename... M>
Ne(M...operands)307 inline ::testing::Matcher<const ::xla::HloInstruction*> Ne(M... operands) {
308   return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
309       ComparisonDirection::kNe, {operands...}));
310 }
311 template <typename... M>
Ge(M...operands)312 inline ::testing::Matcher<const ::xla::HloInstruction*> Ge(M... operands) {
313   return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
314       ComparisonDirection::kGe, {operands...}));
315 }
316 template <typename... M>
Gt(M...operands)317 inline ::testing::Matcher<const ::xla::HloInstruction*> Gt(M... operands) {
318   return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
319       ComparisonDirection::kGt, {operands...}));
320 }
321 template <typename... M>
Le(M...operands)322 inline ::testing::Matcher<const ::xla::HloInstruction*> Le(M... operands) {
323   return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
324       ComparisonDirection::kLe, {operands...}));
325 }
326 template <typename... M>
Lt(M...operands)327 inline ::testing::Matcher<const ::xla::HloInstruction*> Lt(M... operands) {
328   return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
329       ComparisonDirection::kLt, {operands...}));
330 }
331 
332 // GetTupleElement(operand, N) matches a GTE instruction which gets the N'th
333 // tuple element of operand, while GetTupleElement(operand) matches any GTE
334 // operation on operand, and GetTupleElement() matches any GTE operation at all.
GetTupleElement(::testing::Matcher<const HloInstruction * > operand,int64 tuple_index)335 inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement(
336     ::testing::Matcher<const HloInstruction*> operand, int64 tuple_index) {
337   return ::testing::MakeMatcher(
338       new ::xla::testing::HloGetTupleElementMatcher(operand, tuple_index));
339 }
GetTupleElement(::testing::Matcher<const HloInstruction * > operand)340 inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement(
341     ::testing::Matcher<const HloInstruction*> operand) {
342   return ::testing::MakeMatcher(
343       new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {operand}));
344 }
GetTupleElement()345 inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement() {
346   return ::testing::MakeMatcher(
347       new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {}));
348 }
349 
350 // - CustomCall(T, operand1, ..., operandN) matches a CustomCall with call
351 //   target T and the given operands.
352 //
353 // - CustomCall(operand1, ..., operandN) matches any CustomCall HLO with the
354 //   given operands.
355 //
356 // - CustomCall() matches any CustomCall HLO at all.
357 template <typename... M>
CustomCall(::testing::Matcher<string> call_target_matcher,M...operands)358 inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall(
359     ::testing::Matcher<string> call_target_matcher, M... operands) {
360   return ::testing::MakeMatcher(new ::xla::testing::HloCustomCallMatcher(
361       call_target_matcher, {operands...}));
362 }
363 // This overload of CustomCall(A, B, C, ...) exists iff A is not convertible to
364 // ::testing::Matcher<string>.  In that case, we want to prefer the overload
365 // above.
366 template <typename FirstM, typename... M,
367           typename Dummy = typename std::enable_if<
368               !std::is_convertible<FirstM, ::testing::Matcher<string>>::value,
369               void>::type*>
CustomCall(FirstM operands_first,M...operands_rest)370 inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall(
371     FirstM operands_first, M... operands_rest) {
372   return ::testing::MakeMatcher(new ::xla::testing::HloMatcher(
373       HloOpcode::kCustomCall, {operands_first, operands_rest...}));
374 }
CustomCall()375 inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall() {
376   return ::testing::MakeMatcher(
377       new ::xla::testing::HloMatcher(HloOpcode::kCustomCall, {}));
378 }
379 
380 // Verifies the shape or the shape and the layout of an HLO instruction against
381 // the provided shape object.
Shape(const class Shape & shape)382 inline ::testing::Matcher<const ::xla::HloInstruction*> Shape(
383     const class Shape& shape) {
384   return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher(shape));
385 }
Shape(absl::string_view shape)386 inline ::testing::Matcher<const ::xla::HloInstruction*> Shape(
387     absl::string_view shape) {
388   return ::testing::MakeMatcher(
389       new ::xla::testing::HloShapeMatcher(ParseShape(shape).ValueOrDie()));
390 }
ShapeWithLayout(const class Shape & shape)391 inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout(
392     const class Shape& shape) {
393   return ::testing::MakeMatcher(
394       new ::xla::testing::HloShapeAndLayoutMatcher(shape));
395 }
ShapeWithLayout(absl::string_view shape)396 inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout(
397     absl::string_view shape) {
398   return ::testing::MakeMatcher(new ::xla::testing::HloShapeAndLayoutMatcher(
399       ParseShape(shape).ValueOrDie()));
400 }
401 
402 // Verifies the value of the HloSharing against the provided sharding object.
Sharding(const HloSharding & sharding)403 inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
404     const HloSharding& sharding) {
405   return ::testing::MakeMatcher(
406       new ::xla::testing::HloShardingMatcher(sharding));
407 }
408 // Matcher for Sharding from sharding string
Sharding(absl::string_view sharding)409 inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
410     absl::string_view sharding) {
411   return ::testing::MakeMatcher(new ::xla::testing::HloShardingMatcher(
412       ParseSharding(sharding).ValueOrDie()));
413 }
414 // Verifies that no HloSharding is set for an HLO instruction.
NoSharding()415 inline ::testing::Matcher<const ::xla::HloInstruction*> NoSharding() {
416   return ::testing::MakeMatcher(
417       new ::xla::testing::HloShardingMatcher(absl::nullopt));
418 }
419 
Dot()420 inline ::testing::Matcher<const ::xla::HloInstruction*> Dot() {
421   return ::testing::MakeMatcher(
422       new ::xla::testing::HloMatcher(::xla::HloOpcode::kDot, {}));
423 }
424 
Dot(::testing::Matcher<const HloInstruction * > lhs_matcher,::testing::Matcher<const HloInstruction * > rhs_matcher)425 inline ::testing::Matcher<const ::xla::HloInstruction*> Dot(
426     ::testing::Matcher<const HloInstruction*> lhs_matcher,
427     ::testing::Matcher<const HloInstruction*> rhs_matcher) {
428   return ::testing::MakeMatcher(new ::xla::testing::HloMatcher(
429       ::xla::HloOpcode::kDot, {lhs_matcher, rhs_matcher}));
430 }
431 
432 // Matches a Dot HLO instruction if it has exactly one lhs contracting dimension
433 // equal to `lhs_contracting_dim` and exactly one rhs contracting dimension
434 // equal to `rhs_contracting_dim`.
435 //
436 // Currently the HLO verifier rejects Dot operations with more than one
437 // contracting dimension (even though we can represent these in the
438 // DotDimensionNumbers proto) so there is no need to generalize this to support
439 // multiple contracting dimensions.
Dot(::testing::Matcher<const HloInstruction * > lhs_matcher,::testing::Matcher<const HloInstruction * > rhs_matcher,int64 lhs_contracting_dim,int64 rhs_contracting_dim)440 inline ::testing::Matcher<const ::xla::HloInstruction*> Dot(
441     ::testing::Matcher<const HloInstruction*> lhs_matcher,
442     ::testing::Matcher<const HloInstruction*> rhs_matcher,
443     int64 lhs_contracting_dim, int64 rhs_contracting_dim) {
444   return ::testing::MakeMatcher(
445       new ::xla::testing::HloDotWithContractingDimsMatcher(
446           lhs_matcher, rhs_matcher, lhs_contracting_dim, rhs_contracting_dim));
447 }
448 
449 // Matcher for asynchronous copies from one memory space to another. Implies
450 // CopyDone(CopyStart(...)) where from_space and to_space is the source and
451 // destination memory spaces, respectively.
AsyncCopy(int64 to_space,int64 from_space,::testing::Matcher<const HloInstruction * > operand_matcher)452 inline ::testing::Matcher<const ::xla::HloInstruction*> AsyncCopy(
453     int64 to_space, int64 from_space,
454     ::testing::Matcher<const HloInstruction*> operand_matcher) {
455   return ::testing::MakeMatcher(new ::xla::testing::HloAsyncCopyMatcher(
456       to_space, from_space, operand_matcher));
457 }
458 
459 #undef HLO_MATCHER
460 }  // namespace opcode_matchers
461 
462 // Helper to convert smart to raw pointers for matching.
463 template <typename Container>
Pointers(const Container & container)464 std::vector<const HloInstruction*> Pointers(const Container& container) {
465   std::vector<const HloInstruction*> result;
466   result.reserve(container.size());
467   for (const auto& entry : container) result.push_back(entry.get());
468   return result;
469 }
470 
471 }  // namespace testing
472 
473 // Tell GMock to print HloInstruction* by value, so error messages are nice.
474 // Has to be in the same namespace as 'HloInstruction'.
475 void PrintTo(const HloInstruction* inst, ::std::ostream* os);
476 
477 }  // namespace xla
478 
479 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
480