1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
18
#include <utility>

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
23
24 namespace xla {
25 namespace testing {
26
27 class HloMatcher : public ::testing::MatcherInterface<const HloInstruction*> {
28 public:
HloMatcher(HloOpcode opcode,std::vector<::testing::Matcher<const HloInstruction * >> operands)29 HloMatcher(HloOpcode opcode,
30 std::vector<::testing::Matcher<const HloInstruction*>> operands)
31 : opcode_(opcode), operands_(operands) {}
32
33 bool MatchAndExplain(const HloInstruction* instruction,
34 ::testing::MatchResultListener* listener) const override;
35
36 void DescribeTo(::std::ostream* os) const override;
37
38 private:
39 HloOpcode opcode_;
40 std::vector<::testing::Matcher<const HloInstruction*>> operands_;
41 };
42
43 // Custom matcher for parameters, which accepts a parameter number.
44 class HloParameterMatcher : public HloMatcher {
45 public:
HloParameterMatcher(int64 parameter_number)46 explicit HloParameterMatcher(int64 parameter_number)
47 : HloMatcher(HloOpcode::kParameter, /*operands=*/{}),
48 parameter_number_(parameter_number) {}
49
50 bool MatchAndExplain(const HloInstruction* instruction,
51 ::testing::MatchResultListener* listener) const override;
52
53 private:
54 int64 parameter_number_;
55 };
56
57 // Custom matcher for comparisons, which accepts a comparison direction.
58 class HloComparisonMatcher : public HloMatcher {
59 public:
HloComparisonMatcher(ComparisonDirection direction,std::vector<::testing::Matcher<const HloInstruction * >> operands)60 explicit HloComparisonMatcher(
61 ComparisonDirection direction,
62 std::vector<::testing::Matcher<const HloInstruction*>> operands)
63 : HloMatcher(HloOpcode::kCompare, operands), direction_(direction) {}
64
65 bool MatchAndExplain(const HloInstruction* instruction,
66 ::testing::MatchResultListener* listener) const override;
67
68 private:
69 ComparisonDirection direction_;
70 };
71
72 // Custom matcher for get-tuple-element instructions, which accepts a tuple
73 // index to match.
74 class HloGetTupleElementMatcher : public HloMatcher {
75 public:
HloGetTupleElementMatcher(::testing::Matcher<const HloInstruction * > operand,int64 tuple_index)76 HloGetTupleElementMatcher(::testing::Matcher<const HloInstruction*> operand,
77 int64 tuple_index)
78 : HloMatcher(HloOpcode::kGetTupleElement, /*operands=*/{operand}),
79 tuple_index_(tuple_index) {}
80
81 bool MatchAndExplain(const HloInstruction* instruction,
82 ::testing::MatchResultListener* listener) const override;
83
84 private:
85 int64 tuple_index_;
86 };
87
88 // Custom matcher for custom-call instructions, which accepts a matcher for its
89 // call target.
90 class HloCustomCallMatcher : public HloMatcher {
91 public:
HloCustomCallMatcher(::testing::Matcher<string> call_target_matcher,std::vector<::testing::Matcher<const HloInstruction * >> operands)92 HloCustomCallMatcher(
93 ::testing::Matcher<string> call_target_matcher,
94 std::vector<::testing::Matcher<const HloInstruction*>> operands)
95 : HloMatcher(HloOpcode::kCustomCall, operands),
96 call_target_matcher_(call_target_matcher) {}
97
98 bool MatchAndExplain(const HloInstruction* instruction,
99 ::testing::MatchResultListener* listener) const override;
100 void DescribeTo(std::ostream* os) const override;
101
102 private:
103 ::testing::Matcher<string> call_target_matcher_;
104 };
105
106 class HloShapeMatcher
107 : public ::testing::MatcherInterface<const HloInstruction*> {
108 public:
HloShapeMatcher(const Shape & shape)109 explicit HloShapeMatcher(const Shape& shape) : shape_(shape) {}
110
111 bool MatchAndExplain(const HloInstruction* instruction,
112 ::testing::MatchResultListener* listener) const override;
113 void DescribeTo(std::ostream* os) const override;
114
115 private:
116 Shape shape_;
117 };
118
119 class HloShapeAndLayoutMatcher
120 : public ::testing::MatcherInterface<const HloInstruction*> {
121 public:
HloShapeAndLayoutMatcher(const Shape & shape)122 explicit HloShapeAndLayoutMatcher(const Shape& shape) : shape_(shape) {}
123
124 bool MatchAndExplain(const HloInstruction* instruction,
125 ::testing::MatchResultListener* listener) const override;
126 void DescribeTo(std::ostream* os) const override;
127
128 private:
129 Shape shape_;
130 };
131
132 // Verify the sharding of an instruction against the provided HloSharding. If a
133 // nullopt is provided for the expected sharding then it checks that no sharding
134 // is present for an instruction.
135 class HloShardingMatcher
136 : public ::testing::MatcherInterface<const HloInstruction*> {
137 public:
HloShardingMatcher(const absl::optional<HloSharding> & sharding)138 explicit HloShardingMatcher(const absl::optional<HloSharding>& sharding)
139 : sharding_(sharding) {}
140
141 bool MatchAndExplain(const HloInstruction* instruction,
142 ::testing::MatchResultListener* listener) const override;
143 void DescribeTo(std::ostream* os) const override;
144
145 private:
146 absl::optional<HloSharding> sharding_;
147 };
148
149 // Matches a Dot HLO instruction with specific LHS and RHS contracting
150 // dimensions.
151 class HloDotWithContractingDimsMatcher : public HloMatcher {
152 public:
HloDotWithContractingDimsMatcher(::testing::Matcher<const HloInstruction * > lhs,::testing::Matcher<const HloInstruction * > rhs,int64 lhs_contracting_dim,int64 rhs_contracting_dim)153 explicit HloDotWithContractingDimsMatcher(
154 ::testing::Matcher<const HloInstruction*> lhs,
155 ::testing::Matcher<const HloInstruction*> rhs, int64 lhs_contracting_dim,
156 int64 rhs_contracting_dim)
157 : HloMatcher(HloOpcode::kDot, /*operands=*/{lhs, rhs}),
158 lhs_contracting_dim_(lhs_contracting_dim),
159 rhs_contracting_dim_(rhs_contracting_dim) {}
160
161 bool MatchAndExplain(const HloInstruction* instruction,
162 ::testing::MatchResultListener* listener) const override;
163 void DescribeTo(std::ostream* os) const override;
164
165 private:
166 int64 lhs_contracting_dim_;
167 int64 rhs_contracting_dim_;
168 };
169
// HloInstruction* matchers for opcode and operands. Example:
//   namespace op = xla::testing::opcode_matchers;
//   EXPECT_THAT(instruction,
//               op::Add(op::Reshape(), op::Add(op::Reshape(), _)));
174 namespace opcode_matchers {
175 #define HLO_MATCHER(opcode) \
176 template <typename... M> \
177 ::testing::Matcher<const ::xla::HloInstruction*> opcode(M... operands) { \
178 return ::testing::MakeMatcher(new ::xla::testing::HloMatcher( \
179 ::xla::HloOpcode::k##opcode, {operands...})); \
180 }
181 HLO_MATCHER(Abs);
182 HLO_MATCHER(Add);
183 HLO_MATCHER(AllToAll);
184 HLO_MATCHER(Bitcast);
185 HLO_MATCHER(Broadcast);
186 HLO_MATCHER(BatchNormGrad);
187 HLO_MATCHER(Call);
188 HLO_MATCHER(Ceil);
189 HLO_MATCHER(Clamp);
190 HLO_MATCHER(Compare);
191 HLO_MATCHER(Concatenate);
192 HLO_MATCHER(Conditional);
193 HLO_MATCHER(Constant);
194 HLO_MATCHER(Convert);
195 HLO_MATCHER(Convolution);
196 HLO_MATCHER(Copy);
197 HLO_MATCHER(AllReduce);
198 HLO_MATCHER(CollectivePermute);
199 HLO_MATCHER(Divide);
200 HLO_MATCHER(Domain);
201 HLO_MATCHER(DynamicSlice);
202 HLO_MATCHER(DynamicUpdateSlice);
203 HLO_MATCHER(Exp);
204 HLO_MATCHER(Floor);
205 HLO_MATCHER(Fusion);
206 HLO_MATCHER(AfterAll);
207 HLO_MATCHER(Iota);
208 HLO_MATCHER(Infeed);
209 HLO_MATCHER(IsFinite);
210 HLO_MATCHER(Log);
211 HLO_MATCHER(And);
212 HLO_MATCHER(Not);
213 HLO_MATCHER(Or);
214 HLO_MATCHER(Xor);
215 HLO_MATCHER(Map);
216 HLO_MATCHER(Maximum);
217 HLO_MATCHER(Minimum);
218 HLO_MATCHER(Multiply);
219 HLO_MATCHER(Negate);
220 HLO_MATCHER(Outfeed);
221 HLO_MATCHER(Pad);
222 HLO_MATCHER(Power);
223 HLO_MATCHER(Recv);
224 HLO_MATCHER(RecvDone);
225 HLO_MATCHER(Reduce);
226 HLO_MATCHER(ReducePrecision);
227 HLO_MATCHER(ReduceWindow);
228 HLO_MATCHER(Remainder);
229 HLO_MATCHER(Reshape);
230 HLO_MATCHER(Reverse);
231 HLO_MATCHER(Rng);
232 HLO_MATCHER(Scatter);
233 HLO_MATCHER(Select);
234 HLO_MATCHER(SelectAndScatter);
235 HLO_MATCHER(Send);
236 HLO_MATCHER(SendDone);
237 HLO_MATCHER(ShiftLeft);
238 HLO_MATCHER(ShiftRightLogical);
239 HLO_MATCHER(ShiftRightArithmetic);
240 HLO_MATCHER(Sign);
241 HLO_MATCHER(Slice);
242 HLO_MATCHER(Sort);
243 HLO_MATCHER(Subtract);
244 HLO_MATCHER(Tanh);
245 HLO_MATCHER(Trace);
246 HLO_MATCHER(Transpose);
247 HLO_MATCHER(Tuple);
248 HLO_MATCHER(TupleSelect);
249 HLO_MATCHER(While);
250
251 // The special cases below let you check additional information about the
252 // HloInstruction, beyond just its opcode and operands. In all cases you can
253 // still use the generic matcher which doesn't check this info.
254 //
255 // Feel free to add additional custom matchers below.
256
257 // - Parameter(N) matches parameter number N.
258 // - Parameter() matches any parameter.
Parameter(int64 parameter_number)259 inline ::testing::Matcher<const ::xla::HloInstruction*> Parameter(
260 int64 parameter_number) {
261 return ::testing::MakeMatcher(
262 new ::xla::testing::HloParameterMatcher(parameter_number));
263 }
Parameter()264 inline ::testing::Matcher<const ::xla::HloInstruction*> Parameter() {
265 return ::testing::MakeMatcher(
266 new ::xla::testing::HloMatcher(HloOpcode::kParameter, {}));
267 }
268
269 // Comparison matchers below do not require any additional arguments.
270 template <typename... M>
Eq(M...operands)271 inline ::testing::Matcher<const ::xla::HloInstruction*> Eq(M... operands) {
272 return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
273 ComparisonDirection::kEq, {operands...}));
274 }
275 template <typename... M>
Ne(M...operands)276 inline ::testing::Matcher<const ::xla::HloInstruction*> Ne(M... operands) {
277 return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
278 ComparisonDirection::kNe, {operands...}));
279 }
280 template <typename... M>
Ge(M...operands)281 inline ::testing::Matcher<const ::xla::HloInstruction*> Ge(M... operands) {
282 return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
283 ComparisonDirection::kGe, {operands...}));
284 }
285 template <typename... M>
Gt(M...operands)286 inline ::testing::Matcher<const ::xla::HloInstruction*> Gt(M... operands) {
287 return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
288 ComparisonDirection::kGt, {operands...}));
289 }
290 template <typename... M>
Le(M...operands)291 inline ::testing::Matcher<const ::xla::HloInstruction*> Le(M... operands) {
292 return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
293 ComparisonDirection::kLe, {operands...}));
294 }
295 template <typename... M>
Lt(M...operands)296 inline ::testing::Matcher<const ::xla::HloInstruction*> Lt(M... operands) {
297 return ::testing::MakeMatcher(new ::xla::testing::HloComparisonMatcher(
298 ComparisonDirection::kLt, {operands...}));
299 }
300
301 // GetTupleElement(operand, N) matches a GTE instruction which gets the N'th
302 // tuple element of operand, while GetTupleElement(operand) matches any GTE
303 // operation on operand, and GetTupleElement() matches any GTE operation at all.
GetTupleElement(::testing::Matcher<const HloInstruction * > operand,int64 tuple_index)304 inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement(
305 ::testing::Matcher<const HloInstruction*> operand, int64 tuple_index) {
306 return ::testing::MakeMatcher(
307 new ::xla::testing::HloGetTupleElementMatcher(operand, tuple_index));
308 }
GetTupleElement(::testing::Matcher<const HloInstruction * > operand)309 inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement(
310 ::testing::Matcher<const HloInstruction*> operand) {
311 return ::testing::MakeMatcher(
312 new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {operand}));
313 }
GetTupleElement()314 inline ::testing::Matcher<const ::xla::HloInstruction*> GetTupleElement() {
315 return ::testing::MakeMatcher(
316 new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {}));
317 }
318
319 // - CustomCall(T, operand1, ..., operandN) matches a CustomCall with call
320 // target T and the given operands.
321 //
322 // - CustomCall(operand1, ..., operandN) matches any CustomCall HLO with the
323 // given operands.
324 //
325 // - CustomCall() matches any CustomCall HLO at all.
326 template <typename... M>
CustomCall(::testing::Matcher<string> call_target_matcher,M...operands)327 inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall(
328 ::testing::Matcher<string> call_target_matcher, M... operands) {
329 return ::testing::MakeMatcher(new ::xla::testing::HloCustomCallMatcher(
330 call_target_matcher, {operands...}));
331 }
332 // This overload of CustomCall(A, B, C, ...) exists iff A is not convertible to
333 // ::testing::Matcher<string>. In that case, we want to prefer the overload
334 // above.
335 template <typename FirstM, typename... M,
336 typename Dummy = typename std::enable_if<
337 !std::is_convertible<FirstM, ::testing::Matcher<string>>::value,
338 void>::type*>
CustomCall(FirstM operands_first,M...operands_rest)339 inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall(
340 FirstM operands_first, M... operands_rest) {
341 return ::testing::MakeMatcher(new ::xla::testing::HloMatcher(
342 HloOpcode::kCustomCall, {operands_first, operands_rest...}));
343 }
CustomCall()344 inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall() {
345 return ::testing::MakeMatcher(
346 new ::xla::testing::HloMatcher(HloOpcode::kCustomCall, {}));
347 }
348
349 // Verifies the shape or the shape and the layout of an HLO instruction against
350 // the provided shape object.
Shape(const class Shape & shape)351 inline ::testing::Matcher<const ::xla::HloInstruction*> Shape(
352 const class Shape& shape) {
353 return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher(shape));
354 }
Shape(absl::string_view shape)355 inline ::testing::Matcher<const ::xla::HloInstruction*> Shape(
356 absl::string_view shape) {
357 return ::testing::MakeMatcher(
358 new ::xla::testing::HloShapeMatcher(ParseShape(shape).ValueOrDie()));
359 }
ShapeWithLayout(const class Shape & shape)360 inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout(
361 const class Shape& shape) {
362 return ::testing::MakeMatcher(
363 new ::xla::testing::HloShapeAndLayoutMatcher(shape));
364 }
ShapeWithLayout(absl::string_view shape)365 inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout(
366 absl::string_view shape) {
367 return ::testing::MakeMatcher(new ::xla::testing::HloShapeAndLayoutMatcher(
368 ParseShape(shape).ValueOrDie()));
369 }
370
371 // Verifies the value of the HloSharing against the provided sharding object.
Sharding(const HloSharding & sharding)372 inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
373 const HloSharding& sharding) {
374 return ::testing::MakeMatcher(
375 new ::xla::testing::HloShardingMatcher(sharding));
376 }
377 // Matcher for Sharding from sharding string
Sharding(absl::string_view sharding)378 inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
379 absl::string_view sharding) {
380 return ::testing::MakeMatcher(new ::xla::testing::HloShardingMatcher(
381 ParseSharding(sharding).ValueOrDie()));
382 }
383 // Verifies that no HloSharding is set for an HLO instruction.
NoSharding()384 inline ::testing::Matcher<const ::xla::HloInstruction*> NoSharding() {
385 return ::testing::MakeMatcher(
386 new ::xla::testing::HloShardingMatcher(absl::nullopt));
387 }
388
Dot(::testing::Matcher<const HloInstruction * > lhs_matcher,::testing::Matcher<const HloInstruction * > rhs_matcher)389 inline ::testing::Matcher<const ::xla::HloInstruction*> Dot(
390 ::testing::Matcher<const HloInstruction*> lhs_matcher,
391 ::testing::Matcher<const HloInstruction*> rhs_matcher) {
392 return ::testing::MakeMatcher(new ::xla::testing::HloMatcher(
393 ::xla::HloOpcode::kDot, {lhs_matcher, rhs_matcher}));
394 }
395
396 // Matches a Dot HLO instruction if it has exactly one lhs contracting dimension
397 // equal to `lhs_contracting_dim` and exactly one rhs contracting dimension
398 // equal to `rhs_contracting_dim`.
399 //
400 // Currently the HLO verifier rejects Dot operations with more than one
401 // contracting dimension (even though we can represent these in the
402 // DotDimensionNumbers proto) so there is no need to generalize this to support
403 // multiple contracting dimensions.
Dot(::testing::Matcher<const HloInstruction * > lhs_matcher,::testing::Matcher<const HloInstruction * > rhs_matcher,int64 lhs_contracting_dim,int64 rhs_contracting_dim)404 inline ::testing::Matcher<const ::xla::HloInstruction*> Dot(
405 ::testing::Matcher<const HloInstruction*> lhs_matcher,
406 ::testing::Matcher<const HloInstruction*> rhs_matcher,
407 int64 lhs_contracting_dim, int64 rhs_contracting_dim) {
408 return ::testing::MakeMatcher(
409 new ::xla::testing::HloDotWithContractingDimsMatcher(
410 lhs_matcher, rhs_matcher, lhs_contracting_dim, rhs_contracting_dim));
411 }
412
413 #undef HLO_MATCHER
414 } // namespace opcode_matchers
415
416 // Helper to convert smart to raw pointers for matching.
417 template <typename Container>
Pointers(const Container & container)418 std::vector<const HloInstruction*> Pointers(const Container& container) {
419 std::vector<const HloInstruction*> result;
420 result.reserve(container.size());
421 for (const auto& entry : container) result.push_back(entry.get());
422 return result;
423 }
424
425 } // namespace testing
426
427 // Tell GMock to print HloInstruction* by value, so error messages are nice.
428 // Has to be in the same namespace as 'HloInstruction'.
429 void PrintTo(const HloInstruction* inst, ::std::ostream* os);
430
431 } // namespace xla
432
433 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
434