1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
18
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
23
24 namespace xla {
25 namespace testing {
26
27 class HloMatcher : public ::testing::MatcherInterface<const HloInstruction*> {
28 public:
HloMatcher(HloOpcode opcode,std::vector<::testing::Matcher<const HloInstruction * >> operands)29 HloMatcher(HloOpcode opcode,
30 std::vector<::testing::Matcher<const HloInstruction*>> operands)
31 : opcode_(opcode), operands_(operands) {}
32
33 bool MatchAndExplain(const HloInstruction* instruction,
34 ::testing::MatchResultListener* listener) const override;
35
36 void DescribeTo(::std::ostream* os) const override;
37
38 private:
39 HloOpcode opcode_;
40 std::vector<::testing::Matcher<const HloInstruction*>> operands_;
41 };
42
43 // Custom matcher for parameters, which accepts a parameter number.
44 class HloParameterMatcher : public HloMatcher {
45 public:
HloParameterMatcher(int64 parameter_number)46 explicit HloParameterMatcher(int64 parameter_number)
47 : HloMatcher(HloOpcode::kParameter, /*operands=*/{}),
48 parameter_number_(parameter_number) {}
49
50 bool MatchAndExplain(const HloInstruction* instruction,
51 ::testing::MatchResultListener* listener) const override;
52
53 private:
54 int64 parameter_number_;
55 };
56
57 // Custom matcher for comparisons, which accepts a comparison direction.
58 class HloComparisonMatcher : public HloMatcher {
59 public:
HloComparisonMatcher(ComparisonDirection direction,std::vector<::testing::Matcher<const HloInstruction * >> operands)60 explicit HloComparisonMatcher(
61 ComparisonDirection direction,
62 std::vector<::testing::Matcher<const HloInstruction*>> operands)
63 : HloMatcher(HloOpcode::kCompare, operands), direction_(direction) {}
64
65 bool MatchAndExplain(const HloInstruction* instruction,
66 ::testing::MatchResultListener* listener) const override;
67
68 private:
69 ComparisonDirection direction_;
70 };
71
72 // Custom matcher for get-tuple-element instructions, which accepts a tuple
73 // index to match.
74 class HloGetTupleElementMatcher : public HloMatcher {
75 public:
HloGetTupleElementMatcher(::testing::Matcher<const HloInstruction * > operand,int64 tuple_index)76 HloGetTupleElementMatcher(::testing::Matcher<const HloInstruction*> operand,
77 int64 tuple_index)
78 : HloMatcher(HloOpcode::kGetTupleElement, /*operands=*/{operand}),
79 tuple_index_(tuple_index) {}
80
81 bool MatchAndExplain(const HloInstruction* instruction,
82 ::testing::MatchResultListener* listener) const override;
83
84 private:
85 int64 tuple_index_;
86 };
87
88 // Custom matcher for custom-call instructions, which accepts a matcher for its
89 // call target.
90 class HloCustomCallMatcher : public HloMatcher {
91 public:
HloCustomCallMatcher(::testing::Matcher<string> call_target_matcher,std::vector<::testing::Matcher<const HloInstruction * >> operands)92 HloCustomCallMatcher(
93 ::testing::Matcher<string> call_target_matcher,
94 std::vector<::testing::Matcher<const HloInstruction*>> operands)
95 : HloMatcher(HloOpcode::kCustomCall, operands),
96 call_target_matcher_(call_target_matcher) {}
97
98 bool MatchAndExplain(const HloInstruction* instruction,
99 ::testing::MatchResultListener* listener) const override;
100 void DescribeTo(std::ostream* os) const override;
101
102 private:
103 ::testing::Matcher<string> call_target_matcher_;
104 };
105
106 class HloShapeMatcher
107 : public ::testing::MatcherInterface<const HloInstruction*> {
108 public:
HloShapeMatcher(const Shape & shape)109 explicit HloShapeMatcher(const Shape& shape) : shape_(shape) {}
110
111 bool MatchAndExplain(const HloInstruction* instruction,
112 ::testing::MatchResultListener* listener) const override;
113 void DescribeTo(std::ostream* os) const override;
114
115 private:
116 Shape shape_;
117 };
118
119 class HloShapeAndLayoutMatcher
120 : public ::testing::MatcherInterface<const HloInstruction*> {
121 public:
HloShapeAndLayoutMatcher(const Shape & shape)122 explicit HloShapeAndLayoutMatcher(const Shape& shape) : shape_(shape) {}
123
124 bool MatchAndExplain(const HloInstruction* instruction,
125 ::testing::MatchResultListener* listener) const override;
126 void DescribeTo(std::ostream* os) const override;
127
128 private:
129 Shape shape_;
130 };
131
132 // Verify the sharding of an instruction against the provided HloSharding. If a
133 // nullopt is provided for the expected sharding then it checks that no sharding
134 // is present for an instruction.
135 class HloShardingMatcher
136 : public ::testing::MatcherInterface<const HloInstruction*> {
137 public:
HloShardingMatcher(const absl::optional<HloSharding> & sharding)138 explicit HloShardingMatcher(const absl::optional<HloSharding>& sharding)
139 : sharding_(sharding) {}
140
141 bool MatchAndExplain(const HloInstruction* instruction,
142 ::testing::MatchResultListener* listener) const override;
143 void DescribeTo(std::ostream* os) const override;
144
145 private:
146 absl::optional<HloSharding> sharding_;
147 };
148
149 // Matches a Dot HLO instruction with specific LHS and RHS contracting
150 // dimensions.
151 class HloDotWithContractingDimsMatcher : public HloMatcher {
152 public:
HloDotWithContractingDimsMatcher(::testing::Matcher<const HloInstruction * > lhs,::testing::Matcher<const HloInstruction * > rhs,int64 lhs_contracting_dim,int64 rhs_contracting_dim)153 explicit HloDotWithContractingDimsMatcher(
154 ::testing::Matcher<const HloInstruction*> lhs,
155 ::testing::Matcher<const HloInstruction*> rhs, int64 lhs_contracting_dim,
156 int64 rhs_contracting_dim)
157 : HloMatcher(HloOpcode::kDot, /*operands=*/{lhs, rhs}),
158 lhs_contracting_dim_(lhs_contracting_dim),
159 rhs_contracting_dim_(rhs_contracting_dim) {}
160
161 bool MatchAndExplain(const HloInstruction* instruction,
162 ::testing::MatchResultListener* listener) const override;
163 void DescribeTo(std::ostream* os) const override;
164
165 private:
166 int64 lhs_contracting_dim_;
167 int64 rhs_contracting_dim_;
168 };
169
170 // Custom matcher for asynchronous copy (CopyStart/CopyDone pair) with specified
171 // source and destination memory spaces.
172 class HloAsyncCopyMatcher : public HloMatcher {
173 public:
HloAsyncCopyMatcher(int64 to_space,int64 from_space,::testing::Matcher<const HloInstruction * > operand)174 HloAsyncCopyMatcher(int64 to_space, int64 from_space,
175 ::testing::Matcher<const HloInstruction*> operand)
176 : HloMatcher(HloOpcode::kCopyDone,
177 {::testing::MakeMatcher(
178 new HloMatcher(HloOpcode::kCopyStart, {operand}))}),
179 to_space_(to_space),
180 from_space_(from_space) {}
181
182 bool MatchAndExplain(const HloInstruction* instruction,
183 ::testing::MatchResultListener* listener) const override;
184 void DescribeTo(std::ostream* os) const override;
185
186 private:
187 int64 to_space_;
188 int64 from_space_;
189 };
190
// HloInstruction* matchers for opcode and operands. Example:
//   namespace op = xla::testing::opcode_matchers;
//   EXPECT_THAT(instruction,
//               op::Add(op::Reshape(), op::Add(op::Reshape(), _)));
195 namespace opcode_matchers {
// Defines a variadic factory function named `opcode`. It returns a matcher for
// instructions with opcode HloOpcode::k<opcode>. Each argument is converted to
// a ::testing::Matcher<const HloInstruction*> and used as an operand matcher,
// in order.
#define HLO_MATCHER(opcode)                                                \
  template <typename... M>                                                 \
  ::testing::Matcher<const ::xla::HloInstruction*> opcode(M... operands) { \
    auto* const hlo_matcher = new ::xla::testing::HloMatcher(              \
        ::xla::HloOpcode::k##opcode, /*operands=*/{operands...});          \
    return ::testing::MakeMatcher(hlo_matcher);                            \
  }
202 HLO_MATCHER(Abs);
203 HLO_MATCHER(Add);
204 HLO_MATCHER(AddDependency);
205 HLO_MATCHER(AfterAll);
206 HLO_MATCHER(AllGather);
207 HLO_MATCHER(AllReduce);
208 HLO_MATCHER(AllToAll);
209 HLO_MATCHER(And);
210 HLO_MATCHER(BatchNormGrad);
211 HLO_MATCHER(Bitcast);
212 HLO_MATCHER(Broadcast);
213 HLO_MATCHER(Call);
214 HLO_MATCHER(Ceil);
215 HLO_MATCHER(Clamp);
216 HLO_MATCHER(CollectivePermute);
217 HLO_MATCHER(Compare);
218 HLO_MATCHER(Concatenate);
219 HLO_MATCHER(Conditional);
220 HLO_MATCHER(Constant);
221 HLO_MATCHER(Convert);
222 HLO_MATCHER(Convolution);
223 HLO_MATCHER(Copy);
224 HLO_MATCHER(CopyDone);
225 HLO_MATCHER(CopyStart);
226 HLO_MATCHER(Divide);
227 HLO_MATCHER(Domain);
228 HLO_MATCHER(DynamicSlice);
229 HLO_MATCHER(DynamicUpdateSlice);
230 HLO_MATCHER(Exp);
231 HLO_MATCHER(Fft);
232 HLO_MATCHER(Floor);
233 HLO_MATCHER(Fusion);
234 HLO_MATCHER(Gather);
235 HLO_MATCHER(GetDimensionSize);
236 HLO_MATCHER(Infeed);
237 HLO_MATCHER(Iota);
238 HLO_MATCHER(IsFinite);
239 HLO_MATCHER(Log);
240 HLO_MATCHER(Map);
241 HLO_MATCHER(Maximum);
242 HLO_MATCHER(Minimum);
243 HLO_MATCHER(Multiply);
244 HLO_MATCHER(Negate);
245 HLO_MATCHER(Not);
246 HLO_MATCHER(Or);
247 HLO_MATCHER(Outfeed);
248 HLO_MATCHER(Pad);
249 HLO_MATCHER(PartitionId);
250 HLO_MATCHER(Power);
251 HLO_MATCHER(Recv);
252 HLO_MATCHER(RecvDone);
253 HLO_MATCHER(Reduce);
254 HLO_MATCHER(ReducePrecision);
255 HLO_MATCHER(ReduceWindow);
256 HLO_MATCHER(Remainder);
257 HLO_MATCHER(ReplicaId);
258 HLO_MATCHER(Reshape);
259 HLO_MATCHER(Reverse);
260 HLO_MATCHER(Rng);
261 HLO_MATCHER(Scatter);
262 HLO_MATCHER(Select);
263 HLO_MATCHER(SelectAndScatter);
264 HLO_MATCHER(Send);
265 HLO_MATCHER(SendDone);
266 HLO_MATCHER(SetDimensionSize);
267 HLO_MATCHER(ShiftLeft);
268 HLO_MATCHER(ShiftRightArithmetic);
269 HLO_MATCHER(ShiftRightLogical);
270 HLO_MATCHER(Sign);
271 HLO_MATCHER(Slice);
272 HLO_MATCHER(Sort);
273 HLO_MATCHER(Subtract);
274 HLO_MATCHER(Tanh);
275 HLO_MATCHER(Trace);
276 HLO_MATCHER(Transpose);
277 HLO_MATCHER(Tuple);
278 HLO_MATCHER(TupleSelect);
279 HLO_MATCHER(While);
280 HLO_MATCHER(Xor);
281
282 // The special cases below let you check additional information about the
283 // HloInstruction, beyond just its opcode and operands. In all cases you can
284 // still use the generic matcher which doesn't check this info.
285 //
286 // Feel free to add additional custom matchers below.
287
288 // - Parameter(N) matches parameter number N.
289 // - Parameter() matches any parameter.
Parameter(int64 parameter_number)290 inline ::testing::Matcher<const ::xla::HloInstruction*> Parameter(
291 int64 parameter_number) {
292 return ::testing::MakeMatcher(
293 new ::xla::testing::HloParameterMatcher(parameter_number));
294 }
Parameter()295 inline ::testing::Matcher<const ::xla::HloInstruction*> Parameter() {
296 return ::testing::MakeMatcher(
297 new ::xla::testing::HloMatcher(HloOpcode::kParameter, {}));
298 }
299
300 // Comparison matchers below do not require any additional arguments.
301 template <typename... M>
Eq(M...operands)302 inline ::testing::Matcher<const ::xla::HloInstruction*> Eq(M... operands) {
303 return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
304 ComparisonDirection::kEq, {operands...}));
305 }
306 template <typename... M>
Ne(M...operands)307 inline ::testing::Matcher<const ::xla::HloInstruction*> Ne(M... operands) {
308 return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
309 ComparisonDirection::kNe, {operands...}));
310 }
311 template <typename... M>
Ge(M...operands)312 inline ::testing::Matcher<const ::xla::HloInstruction*> Ge(M... operands) {
313 return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
314 ComparisonDirection::kGe, {operands...}));
315 }
316 template <typename... M>
Gt(M...operands)317 inline ::testing::Matcher<const ::xla::HloInstruction*> Gt(M... operands) {
318 return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
319 ComparisonDirection::kGt, {operands...}));
320 }
321 template <typename... M>
Le(M...operands)322 inline ::testing::Matcher<const ::xla::HloInstruction*> Le(M... operands) {
323 return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
324 ComparisonDirection::kLe, {operands...}));
325 }
326 template <typename... M>
Lt(M...operands)327 inline ::testing::Matcher<const ::xla::HloInstruction*> Lt(M... operands) {
328 return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
329 ComparisonDirection::kLt, {operands...}));
330 }
331
332 // GetTupleElement(operand, N) matches a GTE instruction which gets the N'th
333 // tuple element of operand, while GetTupleElement(operand) matches any GTE
334 // operation on operand, and GetTupleElement() matches any GTE operation at all.
GetTupleElement(::testing::Matcher<const HloInstruction * > operand,int64 tuple_index)335 inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement(
336 ::testing::Matcher<const HloInstruction*> operand, int64 tuple_index) {
337 return ::testing::MakeMatcher(
338 new ::xla::testing::HloGetTupleElementMatcher(operand, tuple_index));
339 }
GetTupleElement(::testing::Matcher<const HloInstruction * > operand)340 inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement(
341 ::testing::Matcher<const HloInstruction*> operand) {
342 return ::testing::MakeMatcher(
343 new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {operand}));
344 }
GetTupleElement()345 inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement() {
346 return ::testing::MakeMatcher(
347 new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {}));
348 }
349
350 // - CustomCall(T, operand1, ..., operandN) matches a CustomCall with call
351 // target T and the given operands.
352 //
353 // - CustomCall(operand1, ..., operandN) matches any CustomCall HLO with the
354 // given operands.
355 //
356 // - CustomCall() matches any CustomCall HLO at all.
357 template <typename... M>
CustomCall(::testing::Matcher<string> call_target_matcher,M...operands)358 inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall(
359 ::testing::Matcher<string> call_target_matcher, M... operands) {
360 return ::testing::MakeMatcher(new ::xla::testing::HloCustomCallMatcher(
361 call_target_matcher, {operands...}));
362 }
363 // This overload of CustomCall(A, B, C, ...) exists iff A is not convertible to
364 // ::testing::Matcher<string>. In that case, we want to prefer the overload
365 // above.
366 template <typename FirstM, typename... M,
367 typename Dummy = typename std::enable_if<
368 !std::is_convertible<FirstM, ::testing::Matcher<string>>::value,
369 void>::type*>
CustomCall(FirstM operands_first,M...operands_rest)370 inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall(
371 FirstM operands_first, M... operands_rest) {
372 return ::testing::MakeMatcher(new ::xla::testing::HloMatcher(
373 HloOpcode::kCustomCall, {operands_first, operands_rest...}));
374 }
CustomCall()375 inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall() {
376 return ::testing::MakeMatcher(
377 new ::xla::testing::HloMatcher(HloOpcode::kCustomCall, {}));
378 }
379
380 // Verifies the shape or the shape and the layout of an HLO instruction against
381 // the provided shape object.
Shape(const class Shape & shape)382 inline ::testing::Matcher<const ::xla::HloInstruction*> Shape(
383 const class Shape& shape) {
384 return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher(shape));
385 }
Shape(absl::string_view shape)386 inline ::testing::Matcher<const ::xla::HloInstruction*> Shape(
387 absl::string_view shape) {
388 return ::testing::MakeMatcher(
389 new ::xla::testing::HloShapeMatcher(ParseShape(shape).ValueOrDie()));
390 }
ShapeWithLayout(const class Shape & shape)391 inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout(
392 const class Shape& shape) {
393 return ::testing::MakeMatcher(
394 new ::xla::testing::HloShapeAndLayoutMatcher(shape));
395 }
ShapeWithLayout(absl::string_view shape)396 inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout(
397 absl::string_view shape) {
398 return ::testing::MakeMatcher(new ::xla::testing::HloShapeAndLayoutMatcher(
399 ParseShape(shape).ValueOrDie()));
400 }
401
402 // Verifies the value of the HloSharing against the provided sharding object.
Sharding(const HloSharding & sharding)403 inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
404 const HloSharding& sharding) {
405 return ::testing::MakeMatcher(
406 new ::xla::testing::HloShardingMatcher(sharding));
407 }
408 // Matcher for Sharding from sharding string
Sharding(absl::string_view sharding)409 inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
410 absl::string_view sharding) {
411 return ::testing::MakeMatcher(new ::xla::testing::HloShardingMatcher(
412 ParseSharding(sharding).ValueOrDie()));
413 }
414 // Verifies that no HloSharding is set for an HLO instruction.
NoSharding()415 inline ::testing::Matcher<const ::xla::HloInstruction*> NoSharding() {
416 return ::testing::MakeMatcher(
417 new ::xla::testing::HloShardingMatcher(absl::nullopt));
418 }
419
Dot()420 inline ::testing::Matcher<const ::xla::HloInstruction*> Dot() {
421 return ::testing::MakeMatcher(
422 new ::xla::testing::HloMatcher(::xla::HloOpcode::kDot, {}));
423 }
424
Dot(::testing::Matcher<const HloInstruction * > lhs_matcher,::testing::Matcher<const HloInstruction * > rhs_matcher)425 inline ::testing::Matcher<const ::xla::HloInstruction*> Dot(
426 ::testing::Matcher<const HloInstruction*> lhs_matcher,
427 ::testing::Matcher<const HloInstruction*> rhs_matcher) {
428 return ::testing::MakeMatcher(new ::xla::testing::HloMatcher(
429 ::xla::HloOpcode::kDot, {lhs_matcher, rhs_matcher}));
430 }
431
432 // Matches a Dot HLO instruction if it has exactly one lhs contracting dimension
433 // equal to `lhs_contracting_dim` and exactly one rhs contracting dimension
434 // equal to `rhs_contracting_dim`.
435 //
436 // Currently the HLO verifier rejects Dot operations with more than one
437 // contracting dimension (even though we can represent these in the
438 // DotDimensionNumbers proto) so there is no need to generalize this to support
439 // multiple contracting dimensions.
Dot(::testing::Matcher<const HloInstruction * > lhs_matcher,::testing::Matcher<const HloInstruction * > rhs_matcher,int64 lhs_contracting_dim,int64 rhs_contracting_dim)440 inline ::testing::Matcher<const ::xla::HloInstruction*> Dot(
441 ::testing::Matcher<const HloInstruction*> lhs_matcher,
442 ::testing::Matcher<const HloInstruction*> rhs_matcher,
443 int64 lhs_contracting_dim, int64 rhs_contracting_dim) {
444 return ::testing::MakeMatcher(
445 new ::xla::testing::HloDotWithContractingDimsMatcher(
446 lhs_matcher, rhs_matcher, lhs_contracting_dim, rhs_contracting_dim));
447 }
448
449 // Matcher for asynchronous copies from one memory space to another. Implies
450 // CopyDone(CopyStart(...)) where from_space and to_space is the source and
451 // destination memory spaces, respectively.
AsyncCopy(int64 to_space,int64 from_space,::testing::Matcher<const HloInstruction * > operand_matcher)452 inline ::testing::Matcher<const ::xla::HloInstruction*> AsyncCopy(
453 int64 to_space, int64 from_space,
454 ::testing::Matcher<const HloInstruction*> operand_matcher) {
455 return ::testing::MakeMatcher(new ::xla::testing::HloAsyncCopyMatcher(
456 to_space, from_space, operand_matcher));
457 }
458
459 #undef HLO_MATCHER
460 } // namespace opcode_matchers
461
462 // Helper to convert smart to raw pointers for matching.
463 template <typename Container>
Pointers(const Container & container)464 std::vector<const HloInstruction*> Pointers(const Container& container) {
465 std::vector<const HloInstruction*> result;
466 result.reserve(container.size());
467 for (const auto& entry : container) result.push_back(entry.get());
468 return result;
469 }
470
471 } // namespace testing
472
473 // Tell GMock to print HloInstruction* by value, so error messages are nice.
474 // Has to be in the same namespace as 'HloInstruction'.
475 void PrintTo(const HloInstruction* inst, ::std::ostream* os);
476
477 } // namespace xla
478
479 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
480