/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cudnn_simplify_padding.h"

#include <memory>
#include <utility>

#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_pad_for_convolutions.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_vectorize_convolutions.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/statusor.h"
namespace xla::gpu {
namespace {

namespace m = ::xla::match;

class CudnnSimplifyPaddingTest : public HloTestBase {
 protected:
  // Runs the whole relevant pass pipeline starting at CudnnPadForConvolutions.
  // This lets us test that we're matching the patterns that actually get
  // generated by padding+vectorization.
  StatusOr<bool> RunEndToEnd(std::pair<int, int> compute_capability,
                             HloModule* module) {
    se::CudaComputeCapability cc{compute_capability.first,
                                 compute_capability.second};

    TF_RETURN_IF_ERROR(
        RunHloPass(CudnnPadForConvolutions(cc), module).status());

    TF_RETURN_IF_ERROR(
        RunHloPass(CudnnVectorizeConvolutions(cc), module).status());
    VLOG(1) << "after vectorizing convs:\n" << module->ToString();

    TF_RETURN_IF_ERROR(RunHloPass(CallInliner(), module).status());
    VLOG(1) << "after inliner:\n" << module->ToString();

    TF_RETURN_IF_ERROR(RunHloPass(TupleSimplifier(), module).status());
    VLOG(1) << "after tuple simplifier:\n" << module->ToString();

    TF_ASSIGN_OR_RETURN(bool changed,
                        RunHloPass(CudnnSimplifyPadding(), module));
    VLOG(1) << "after simplify_padding:\n" << module->ToString();

    TF_RETURN_IF_ERROR(RunHloPass(HloPassFix<AlgebraicSimplifier>(
                                      AlgebraicSimplifierOptions()),
                                  module)
                           .status());
    VLOG(1) << "after algsimp:\n" << module->ToString();

    return changed;
  }

  StatusOr<bool> RunJustThisPass(HloModule* module) {
    TF_ASSIGN_OR_RETURN(bool changed,
                        RunHloPass(CudnnSimplifyPadding(), module));
    VLOG(1) << "after simplify_padding:\n" << module->ToString();

    // I know the name says "just this pass", but you really want algsimp too,
    // otherwise the resulting patterns are ugly/hard to match.
    TF_RETURN_IF_ERROR(RunHloPass(HloPassFix<AlgebraicSimplifier>(
                                      AlgebraicSimplifierOptions()),
                                  module)
                           .status());
    return changed;
  }
};

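// Checks that padding config `p` has zero low padding on every dimension and
// nonzero high padding only on `dim`, where it must equal `padding_high`.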
void ExpectOnlyPadsOneDim(int64_t dim, int64_t padding_high,
                          const PaddingConfig& p) {
  SCOPED_TRACE(p.DebugString());
  for (int i = 0; i < p.dimensions_size(); ++i) {
    SCOPED_TRACE(absl::StrCat("dimension ", i));
    EXPECT_EQ(p.dimensions(i).edge_padding_low(), 0);
    if (i == dim) {
      EXPECT_EQ(p.dimensions(i).edge_padding_high(), padding_high);
    } else {
      EXPECT_EQ(p.dimensions(i).edge_padding_high(), 0);
    }
  }
}

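// Replaces the constant `instr` with a new constant whose elements are
// recomputed by `value_fn`, which receives each element's indices and its old
// value.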
template <typename NativeT>
void SetConstantValue(
    HloInstruction* instr,
    absl::FunctionRef<NativeT(absl::Span<const int64_t>, NativeT)> value_fn) {
  Literal new_literal = instr->literal().Clone();
  new_literal.MutableEachCell<NativeT>(value_fn);
  TF_EXPECT_OK(instr->parent()->ReplaceWithNewInstruction(
      instr, HloInstruction::CreateConstant(std::move(new_literal))));
}

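// Runs the full pipeline on two chained int8 convolutions. The
// padding/vectorization passes insert pads, reshapes, and transposes between
// them; after CudnnSimplifyPadding + algsimp, conv2 should consume conv1's
// result directly.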
TEST_F(CudnnSimplifyPaddingTest, EndToEnd) {
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    conv1 = (s8[10,20,30,190], u8[0]) custom-call(
        s8[10,20,30,63] parameter(0), s8[3,5,63,190] parameter(1),
        f32[10] parameter(2), s8[10,20,30,190] parameter(3)),
      window={size=3x5}, dim_labels=b01f_01io->b01f,
      custom_call_target="__cudnn$convBiasActivationForward"
    conv1_result = get-tuple-element(conv1), index=0
    ROOT conv2 = (s8[10,20,30,29], u8[0]) custom-call(
        conv1_result, s8[3,5,190,29] parameter(4),
        f32[10] parameter(5), s8[10,20,30,29] parameter(6)),
      window={size=3x5}, dim_labels=b01f_01io->b01f,
      custom_call_target="__cudnn$convBiasActivationForward"
  })")
                    .ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunEndToEnd({7, 5}, module.get()));
  EXPECT_TRUE(changed);

  SCOPED_TRACE(module->ToString());
  auto* root = module->entry_computation()->root_instruction();

  // conv2 should be fed directly from conv1, without any intervening
  // reshapes/pads.
  EXPECT_THAT(
      root, GmockMatch(m::Tuple(
                m::Slice(m::Reshape(m::GetTupleElement(m::CustomCall(
                    "__cudnn$convBiasActivationForward",
                    m::GetTupleElement(
                        m::CustomCall("__cudnn$convBiasActivationForward"), 0),
                    m::Op(), m::Op(), m::Op())))),
                m::Op())));
}

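// The weights are zero-padded by 4 in the output-feature dim, so the output
// features removed by the slice are known to be 0, and the slice+pad pair can
// fold into a single pad (high padding of 1 on the feature dim).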
TEST_F(CudnnSimplifyPaddingTest, PaddedWeights) {
  auto module = ParseAndReturnVerifiedModule(R"(
    HloModule TestModule

    ENTRY TestComputation {
      weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
      conv = (s8[10,10,10,10], u8[0]) custom-call(
          s8[10,10,10,10] parameter(1),
          weights
        ), window={size=3x3}, dim_labels=b01f_01io->b01f,
        custom_call_target="__cudnn$convForward"
      conv_result = get-tuple-element(conv), index=0
      slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
    }
  )")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_TRUE(changed);

  SCOPED_TRACE(module->ToString());
  auto* root = module->entry_computation()->root_instruction();
  const HloInstruction* pad = nullptr;
  ASSERT_THAT(root,
              GmockMatch(m::Pad(&pad, m::GetTupleElement(m::CustomCall(), 0),
                                m::ConstantScalar(0))));

  ExpectOnlyPadsOneDim(/*dim=*/3, /*padding_high=*/1, pad->padding_config());
}

// This is similar to PaddedWeights, except the weights are only zero-padded
// by 3 elements while we slice off 4 elements from the output features. As a
// result, not all of the sliced elements are 0, and we can't merge the slice
// into the pad that follows.
TEST_F(CudnnSimplifyPaddingTest, PaddedWeightsNotPaddedEnough) {
  auto module = ParseAndReturnVerifiedModule(R"(
    HloModule TestModule

    ENTRY TestComputation {
      weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_3
      conv = (s8[10,10,10,10], u8[0]) custom-call(
          s8[10,10,10,10] parameter(1),
          weights
        ), window={size=3x3}, dim_labels=b01f_01io->b01f,
        custom_call_target="__cudnn$convForward"
      conv_result = get-tuple-element(conv), index=0
      slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
    }
  )")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_FALSE(changed);
}

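// Like PaddedWeights, but the padded filter is reshaped for an
// NCHW_VECT_C-style vectorized conv (the 64 features split into 2x32), so the
// pass has to see through the reshape on the conv output before it can fold
// the slice+pad pair.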
TEST_F(CudnnSimplifyPaddingTest, PaddedAndReshapedWeightsNCHW) {
  auto module = ParseAndReturnVerifiedModule(R"(
    HloModule TestModule

    ENTRY TestComputation {
      weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
      weights = s8[2,32,64,3,3] reshape(weights_p)
      conv = (s8[10,2,32,10,10], u8[0]) custom-call(
          s8[10,2,32,10,10] parameter(1),
          weights
        ), window={size=3x3}, dim_labels=bf?01_i?o01->bf?01,
        custom_call_target="__cudnn$convForward"
      conv_result = get-tuple-element(conv), index=0
      slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_result)), slice={[0:10], [0:60], [0:10], [0:10]}
      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_5x0_0x0_0
    }
  )")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_TRUE(changed);

  SCOPED_TRACE(module->ToString());
  auto* root = module->entry_computation()->root_instruction();
  const HloInstruction* pad = nullptr;
  ASSERT_THAT(
      root, GmockMatch(
                m::Pad(&pad, m::Reshape(m::GetTupleElement(m::CustomCall(), 0)),
                       m::ConstantScalar(0))));

  ExpectOnlyPadsOneDim(/*dim=*/1, /*padding_high=*/1, pad->padding_config());
}

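// Same as above, but with the vectorized layout in NHWC form (the vect_c dim
// trails the feature dim).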
TEST_F(CudnnSimplifyPaddingTest, PaddedAndReshapedWeightsNHWC) {
  auto module = ParseAndReturnVerifiedModule(R"(
    HloModule TestModule

    ENTRY TestComputation {
      weights_p = pad(s8[3,3,64,60] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
      weights = s8[3,3,2,32,64] reshape(weights_p)
      conv = (s8[10,10,10,2,32], u8[0]) custom-call(
          s8[10,10,10,2,32] parameter(1),
          weights
        ), window={size=3x3}, dim_labels=b01f?_01i?o->b01f?,
        custom_call_target="__cudnn$convForward"
      conv_result = get-tuple-element(conv), index=0
      slice = s8[10,10,10,60] slice(s8[10,10,10,64] reshape(conv_result)), slice={[0:10], [0:10], [0:10], [0:60]}
      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
    }
  )")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_TRUE(changed);

  SCOPED_TRACE(module->ToString());
  auto* root = module->entry_computation()->root_instruction();
  const HloInstruction* pad = nullptr;
  ASSERT_THAT(
      root, GmockMatch(
                m::Pad(&pad, m::Reshape(m::GetTupleElement(m::CustomCall(), 0)),
                       m::ConstantScalar(0))));

  ExpectOnlyPadsOneDim(/*dim=*/3, /*padding_high=*/1, pad->padding_config());
}

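// Here the conv output is additionally transposed before being reshaped and
// sliced. The pass should still fold the slice (64 -> 60) and pad (60 -> 66)
// into a single high-padding of 2 on the feature dim.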
TEST_F(CudnnSimplifyPaddingTest, PaddedTransposedAndReshapedOutput) {
  auto module = ParseAndReturnVerifiedModule(R"(
    HloModule TestModule

    ENTRY TestComputation {
      weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
      weights = s8[2,32,64,3,3] reshape(weights_p)
      conv = (s8[10,2,10,10,32], u8[0]) custom-call(
          s8[10,2,10,10,32] parameter(1),
          weights
        ), window={size=3x3}, dim_labels=bf01?_i?o01->bf01?,
        custom_call_target="__cudnn$convForward"
      conv_result = get-tuple-element(conv), index=0
      conv_transposed = s8[10,2,32,10,10] transpose(conv_result), dimensions={0,1,4,2,3}
      slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_transposed)), slice={[0:10], [0:60], [0:10], [0:10]}
      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_6x0_0x0_0
    }
  )")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_TRUE(changed);

  SCOPED_TRACE(module->ToString());
  auto* root = module->entry_computation()->root_instruction();
  const HloInstruction* pad = nullptr;
  ASSERT_THAT(
      root,
      GmockMatch(m::Pad(
          &pad,
          m::Reshape(m::Transpose(m::GetTupleElement(m::CustomCall(), 0))),
          m::ConstantScalar(0))));

  ExpectOnlyPadsOneDim(/*dim=*/1, /*padding_high=*/2, pad->padding_config());
}

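// The known-zero output features come from a constant literal rather than an
// explicit pad: output features 6..9 of the weights are set to 0 below, so
// the pass can prove the sliced-off conv features are 0 by inspecting the
// literal.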
TEST_F(CudnnSimplifyPaddingTest, PaddedConstantWeight) {
  auto module = ParseAndReturnVerifiedModule(R"(
    HloModule TestModule

    ENTRY TestComputation {
      conv = (s8[10,10,10,10], u8[0]) custom-call(
          s8[10,10,10,10] parameter(0),
          s8[3,3,10,10] constant({...})
        ), window={size=3x3}, dim_labels=b01f_01io->b01f,
        custom_call_target="__cudnn$convForward"
      conv_result = get-tuple-element(conv), index=0
      slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
    }
  )")
                    .ValueOrDie();

  // Set the constant's value.  (The HLO text above sets it to all 0s.)
  {
    HloInstruction* weights = nullptr;
    ASSERT_THAT(module->entry_computation()->root_instruction(),
                GmockMatch(m::Pad(m::Slice(m::GetTupleElement(m::CustomCall(
                                      m::Op(), m::Constant(&weights)))),
                                  m::Op())));
    SetConstantValue<int8_t>(
        weights, [](absl::Span<const int64_t> dims, int8_t old_val) -> int8_t {
          if (dims[3] < 6) return 1;
          return 0;
        });
  }

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_TRUE(changed);

  SCOPED_TRACE(module->ToString());
  auto* root = module->entry_computation()->root_instruction();
  const HloInstruction* pad = nullptr;
  ASSERT_THAT(root,
              GmockMatch(m::Pad(&pad, m::GetTupleElement(m::CustomCall(), 0),
                                m::ConstantScalar(0))));

  ExpectOnlyPadsOneDim(/*dim=*/3, /*padding_high=*/1, pad->padding_config());
}

TEST_F(CudnnSimplifyPaddingTest, PaddedConstantWeightIsNotLargeEnough) {
  auto module = ParseAndReturnVerifiedModule(R"(
    HloModule TestModule

    ENTRY TestComputation {
      conv = (s8[10,10,10,10], u8[0]) custom-call(
          s8[10,10,10,10] parameter(0),
          s8[3,3,10,10] constant({...})
        ), window={size=3x3}, dim_labels=b01f_01io->b01f,
        custom_call_target="__cudnn$convForward"
      conv_result = get-tuple-element(conv), index=0
      slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
    }
  )")
                    .ValueOrDie();

  // Set the constant's value.  (The HLO text above sets it to all 0s.)
  {
    HloInstruction* weights = nullptr;
    ASSERT_THAT(module->entry_computation()->root_instruction(),
                GmockMatch(m::Pad(m::Slice(m::GetTupleElement(m::CustomCall(
                                      m::Op(), m::Constant(&weights)))),
                                  m::Op())));
    SetConstantValue<int8_t>(
        weights, [](absl::Span<const int64_t> dims, int8_t old_val) -> int8_t {
          // Output features 0..4 are 0; features 5..9 (which include all the
          // features removed by the slice) are nonzero.
          if (dims[3] < 5 /*|| (dims[3] == 5 && dims[2] > 1)*/) return 0;
          return 1;
        });
  }

  // Some of the values sliced off are not 0, so we can't merge the slice into
  // the pad.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_FALSE(changed);
}

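// The reshape of the conv output would have to merge the feature dim (2) with
// the non-adjacent vect_c dim (32), which a plain reshape can't express, so
// the pass should leave the module unchanged.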
TEST_F(CudnnSimplifyPaddingTest, ReshapeDoesntMergeVectCDim) {
  auto module = ParseAndReturnVerifiedModule(R"(
    HloModule TestModule

    ENTRY TestComputation {
      weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
      weights = s8[2,64,3,3,32] reshape(weights_p)
      conv = (s8[10,2,10,10,32], u8[0]) custom-call(
          s8[10,2,10,10,32] parameter(1),
          weights_p
        ), window={size=3x3}, dim_labels=bf01?_io01?->bf01?,
        custom_call_target="__cudnn$convForward"
      conv_result = get-tuple-element(conv), index=0
      slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_result)), slice={[0:10], [0:60], [0:10], [0:10]}
      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_6x0_0x0_0
    }
  )")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_FALSE(changed);
}

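// The conv output has two trailing vect_c-like dims (4 and 8); the pass only
// understands a single vect_c dim, so it should bail out.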
TEST_F(CudnnSimplifyPaddingTest, TwoVectCDimsInOutput) {
  auto module = ParseAndReturnVerifiedModule(R"(
    HloModule TestModule

    ENTRY TestComputation {
      weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
      weights = s8[2,64,3,3,32] reshape(weights_p)
      conv = (s8[10,2,10,10,4,8], u8[0]) custom-call(
          s8[10,2,10,10,32] parameter(1),
          weights
        ), window={size=3x3}, dim_labels=bf01?_io01?->bf01??,
        custom_call_target="__cudnn$convForward"
      conv_result = get-tuple-element(conv), index=0
      conv_transposed = s8[10,2,4,8,10,10] transpose(conv_result), dimensions={0,1,4,5,2,3}
      slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_transposed)), slice={[0:10], [0:60], [0:10], [0:10]}
      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_6x0_0x0_0
    }
  )")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_FALSE(changed);
}

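// Same, but the two vect_c-like dims (4 and 8) are in the kernel rather than
// the output.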
TEST_F(CudnnSimplifyPaddingTest, TwoVectCDimsInKernel) {
  auto module = ParseAndReturnVerifiedModule(R"(
    HloModule TestModule

    ENTRY TestComputation {
      weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
      weights = s8[2,64,3,3,4,8] reshape(weights_p)
      conv = (s8[10,2,10,10,32], u8[0]) custom-call(
          s8[10,2,10,10,32] parameter(1),
          weights
        ), window={size=3x3}, dim_labels=bf01?_io01??->bf01?,
        custom_call_target="__cudnn$convForward"
      conv_result = get-tuple-element(conv), index=0
      conv_transposed = s8[10,2,32,10,10] transpose(conv_result), dimensions={0,1,4,2,3}
      slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_transposed)), slice={[0:10], [0:60], [0:10], [0:10]}
      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_6x0_0x0_0
    }
  )")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_FALSE(changed);
}

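// The slice also trims the front of a spatial dim ([1:10]); the elements
// dropped there aren't known to be 0, so the slice can't be folded away.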
TEST_F(CudnnSimplifyPaddingTest, SliceDoesntStartAtBeginning) {
  auto module = ParseAndReturnVerifiedModule(R"(
    HloModule TestModule

    ENTRY TestComputation {
      weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
      conv = (s8[10,10,10,10], u8[0]) custom-call(
          s8[10,10,10,10] parameter(1),
          weights
        ), window={size=3x3}, dim_labels=b01f_01io->b01f,
        custom_call_target="__cudnn$convForward"
      conv_result = get-tuple-element(conv), index=0
      slice = s8[10,9,10,6] slice(conv_result), slice={[0:10], [1:10], [0:10], [0:6]}
      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
    }
  )")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_FALSE(changed);
}

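// The slice takes [1:6] of the feature dim rather than a prefix; the elements
// dropped at the front aren't the known-zero trailing features, so no change.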
TEST_F(CudnnSimplifyPaddingTest, SliceDoesntStartAtBeginningOfFeatureDim) {
  auto module = ParseAndReturnVerifiedModule(R"(
    HloModule TestModule

    ENTRY TestComputation {
      weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
      conv = (s8[10,10,10,10], u8[0]) custom-call(
          s8[10,10,10,10] parameter(1),
          weights
        ), window={size=3x3}, dim_labels=b01f_01io->b01f,
        custom_call_target="__cudnn$convForward"
      conv_result = get-tuple-element(conv), index=0
      slice = s8[10,10,10,5] slice(conv_result), slice={[0:10], [0:10], [0:10], [1:6]}
      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
    }
  )")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_FALSE(changed);
}

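// A strided slice ([0:6:2]) doesn't select a contiguous prefix of the feature
// dim, so it can't be merged into the pad.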
TEST_F(CudnnSimplifyPaddingTest, SliceHasStride) {
  auto module = ParseAndReturnVerifiedModule(R"(
    HloModule TestModule

    ENTRY TestComputation {
      weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
      conv = (s8[10,10,10,10], u8[0]) custom-call(
          s8[10,10,10,10] parameter(1),
          weights
        ), window={size=3x3}, dim_labels=b01f_01io->b01f,
        custom_call_target="__cudnn$convForward"
      conv_result = get-tuple-element(conv), index=0
      slice = s8[10,10,10,3] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6:2]}
      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
    }
  )")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_FALSE(changed);
}

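// The final pad adds interior padding on the feature dim (the trailing _1 in
// 0_5_1), which the slice+pad rewrite can't produce, so the pass should leave
// the module alone.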
TEST_F(CudnnSimplifyPaddingTest, PadAddsInteriorPadding) {
  auto module = ParseAndReturnVerifiedModule(R"(
    HloModule TestModule

    ENTRY TestComputation {
      weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
      conv = (s8[10,10,10,10], u8[0]) custom-call(
          s8[10,10,10,10] parameter(1),
          weights
        ), window={size=3x3}, dim_labels=b01f_01io->b01f,
        custom_call_target="__cudnn$convForward"
      conv_result = get-tuple-element(conv), index=0
      slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5_1
    }
  )")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_FALSE(changed);
}

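// The slice removes 4 features but the pad only restores 2, so the net effect
// is shrinking the feature dim from 10 to 8; after algsimp this shows up as a
// plain slice.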
TEST_F(CudnnSimplifyPaddingTest, SliceMoreElementsThanPad) {
  auto module = ParseAndReturnVerifiedModule(R"(
    HloModule TestModule

    ENTRY TestComputation {
      weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
      conv = (s8[10,10,10,10], u8[0]) custom-call(
          s8[10,10,10,10] parameter(1),
          weights
        ), window={size=3x3}, dim_labels=b01f_01io->b01f,
        custom_call_target="__cudnn$convForward"
      conv_result = get-tuple-element(conv), index=0
      slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
      ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_2
    }
  )")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_TRUE(changed);

  SCOPED_TRACE(module->ToString());
  auto* root = module->entry_computation()->root_instruction();
  const HloInstruction* slice = nullptr;
  // The pass creates a pad with negative padding; this is simplified by
  // algsimp into a slice.
  ASSERT_THAT(root, GmockMatch(m::Slice(
                        &slice, m::GetTupleElement(m::CustomCall(), 0))));
  for (int64_t i = 0; i < slice->shape().dimensions_size(); ++i) {
    SCOPED_TRACE(i);
    EXPECT_EQ(slice->slice_starts(i), 0);
    EXPECT_EQ(slice->slice_strides(i), 1);
    if (i != 3) {
      EXPECT_EQ(slice->slice_limits(i), 10);
    } else {
      EXPECT_EQ(slice->slice_limits(i), 8);
    }
  }
}

}  // anonymous namespace
}  // namespace xla::gpu