/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cudnn_simplify_padding.h"

#include <memory>
#include <utility>

#include "absl/functional/function_ref.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_pad_for_convolutions.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_vectorize_convolutions.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/statusor.h"

namespace xla::gpu {
namespace {

namespace m = ::xla::match;

class CudnnSimplifyPaddingTest : public HloTestBase {
 protected:
  // Runs the whole relevant pass pipeline starting at CudnnPadForConvolutions.
  // This lets us test that we're matching the patterns that actually get
  // generated by padding+vectorization.
  StatusOr<bool> RunEndToEnd(std::pair<int, int> compute_capability,
                             HloModule* module) {
    se::CudaComputeCapability cc{compute_capability.first,
                                 compute_capability.second};

    TF_RETURN_IF_ERROR(
        RunHloPass(CudnnPadForConvolutions(cc), module).status());

    TF_RETURN_IF_ERROR(
        RunHloPass(CudnnVectorizeConvolutions(cc), module).status());
    VLOG(1) << "after vectorizing convs:\n" << module->ToString();

    TF_RETURN_IF_ERROR(RunHloPass(CallInliner(), module).status());
    VLOG(1) << "after inliner:\n" << module->ToString();

    TF_RETURN_IF_ERROR(RunHloPass(TupleSimplifier(), module).status());
    VLOG(1) << "after tuple simplifier:\n" << module->ToString();

    TF_ASSIGN_OR_RETURN(bool changed,
                        RunHloPass(CudnnSimplifyPadding(), module));
    VLOG(1) << "after simplify_padding:\n" << module->ToString();

    TF_RETURN_IF_ERROR(RunHloPass(HloPassFix<AlgebraicSimplifier>(
                                      AlgebraicSimplifierOptions()),
                                  module)
                           .status());
    VLOG(1) << "after algsimp:\n" << module->ToString();

    return changed;
  }

  StatusOr<bool> RunJustThisPass(HloModule* module) {
    TF_ASSIGN_OR_RETURN(bool changed,
                        RunHloPass(CudnnSimplifyPadding(), module));
    VLOG(1) << "after simplify_padding:\n" << module->ToString();

    // I know the name says "just this pass", but you really want algsimp too,
    // otherwise the resulting patterns are ugly/hard to match.
    TF_RETURN_IF_ERROR(RunHloPass(HloPassFix<AlgebraicSimplifier>(
                                      AlgebraicSimplifierOptions()),
                                  module)
                           .status());
    return changed;
  }
};

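// Asserts that `p` has zero low padding everywhere, and zero high padding on
// every dimension except `dim`, whose high padding must be `padding_high`.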
void ExpectOnlyPadsOneDim(int64_t dim, int64_t padding_high,
                          const PaddingConfig& p) {
  SCOPED_TRACE(p.DebugString());
  for (int i = 0; i < p.dimensions_size(); ++i) {
    SCOPED_TRACE(absl::StrCat("dimension ", i));
    EXPECT_EQ(p.dimensions(i).edge_padding_low(), 0);
    if (i == dim) {
      EXPECT_EQ(p.dimensions(i).edge_padding_high(), padding_high);
    } else {
      EXPECT_EQ(p.dimensions(i).edge_padding_high(), 0);
    }
  }
}

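// Replaces the constant `instr` with a new constant whose cells are computed
// by `value_fn(multi_index, old_value)`.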
template <typename NativeT>
void SetConstantValue(
    HloInstruction* instr,
    absl::FunctionRef<NativeT(absl::Span<const int64_t>, NativeT)> value_fn) {
  Literal new_literal = instr->literal().Clone();
  new_literal.MutableEachCell<NativeT>(value_fn);
  TF_EXPECT_OK(instr->parent()->ReplaceWithNewInstruction(
      instr, HloInstruction::CreateConstant(std::move(new_literal))));
}

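// Runs the full pipeline on two back-to-back s8 convolutions and checks that
// the padding/slicing glue between them is removed: conv2 should consume
// conv1's result directly.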
TEST_F(CudnnSimplifyPaddingTest, EndToEnd) {
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    conv1 = (s8[10,20,30,190], u8[0]) custom-call(
      s8[10,20,30,63] parameter(0), s8[3,5,63,190] parameter(1),
      f32[10] parameter(2), s8[10,20,30,190] parameter(3)),
      window={size=3x5}, dim_labels=b01f_01io->b01f,
      custom_call_target="__cudnn$convBiasActivationForward"
    conv1_result = get-tuple-element(conv1), index=0
    ROOT conv2 = (s8[10,20,30,29], u8[0]) custom-call(
      conv1_result, s8[3,5,190,29] parameter(4),
      f32[10] parameter(5), s8[10,20,30,29] parameter(6)),
      window={size=3x5}, dim_labels=b01f_01io->b01f,
      custom_call_target="__cudnn$convBiasActivationForward"
  })")
                    .ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunEndToEnd({7, 5}, module.get()));
  EXPECT_TRUE(changed);

  SCOPED_TRACE(module->ToString());
  auto* root = module->entry_computation()->root_instruction();

  // conv2 should be fed directly from conv1, without any intervening
  // reshapes/pads.
  EXPECT_THAT(
      root, GmockMatch(m::Tuple(
                m::Slice(m::Reshape(m::GetTupleElement(m::CustomCall(
                    "__cudnn$convBiasActivationForward",
                    m::GetTupleElement(
                        m::CustomCall("__cudnn$convBiasActivationForward"), 0),
                    m::Op(), m::Op(), m::Op())))),
                m::Op())));
}

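// The weights' output-feature dimension is padded with 4 zeros, so the last 4
// channels of the conv result are known to be 0. Slicing those channels off
// and re-padding should collapse into a single pad of the conv result.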
TEST_F(CudnnSimplifyPaddingTest, PaddedWeights) {
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
    conv = (s8[10,10,10,10], u8[0]) custom-call(
      s8[10,10,10,10] parameter(1),
      weights
    ), window={size=3x3}, dim_labels=b01f_01io->b01f,
      custom_call_target="__cudnn$convForward"
    conv_result = get-tuple-element(conv), index=0
    slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
    ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
  }
  )")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_TRUE(changed);

  SCOPED_TRACE(module->ToString());
  auto* root = module->entry_computation()->root_instruction();
  const HloInstruction* pad = nullptr;
  ASSERT_THAT(root,
              GmockMatch(m::Pad(&pad, m::GetTupleElement(m::CustomCall(), 0),
                                m::ConstantScalar(0))));

  ExpectOnlyPadsOneDim(/*dim=*/3, /*padding_high=*/1, pad->padding_config());
}

// This is similar to PaddedWeights, except that only 3 elements of the weights
// are padded to 0 while we slice off 4 elements from the output features. As a
// result, not all of the sliced elements are 0, and we can't merge the slice
// into the pad that follows.
TEST_F(CudnnSimplifyPaddingTest, PaddedWeightsNotPaddedEnough) {
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_3
    conv = (s8[10,10,10,10], u8[0]) custom-call(
      s8[10,10,10,10] parameter(1),
      weights
    ), window={size=3x3}, dim_labels=b01f_01io->b01f,
      custom_call_target="__cudnn$convForward"
    conv_result = get-tuple-element(conv), index=0
    slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
    ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
  }
  )")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_FALSE(changed);
}

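// Like PaddedWeights, but with NCHW-layout operands whose channel dimensions
// have been split for vectorization: the conv output s8[10,2,32,10,10] is
// reshaped back to s8[10,64,10,10] before the slice+pad, and the pass still
// has to see through the reshape.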
TEST_F(CudnnSimplifyPaddingTest, PaddedAndReshapedWeightsNCHW) {
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
    weights = s8[2,32,64,3,3] reshape(weights_p)
    conv = (s8[10,2,32,10,10], u8[0]) custom-call(
      s8[10,2,32,10,10] parameter(1),
      weights
    ), window={size=3x3}, dim_labels=bf?01_i?o01->bf?01,
      custom_call_target="__cudnn$convForward"
    conv_result = get-tuple-element(conv), index=0
    slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_result)), slice={[0:10], [0:60], [0:10], [0:10]}
    ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_5x0_0x0_0
  }
  )")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_TRUE(changed);

  SCOPED_TRACE(module->ToString());
  auto* root = module->entry_computation()->root_instruction();
  const HloInstruction* pad = nullptr;
  ASSERT_THAT(
      root, GmockMatch(
                m::Pad(&pad, m::Reshape(m::GetTupleElement(m::CustomCall(), 0)),
                       m::ConstantScalar(0))));

  ExpectOnlyPadsOneDim(/*dim=*/1, /*padding_high=*/1, pad->padding_config());
}

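// Same as above, but in NHWC layout, where the split channel dimensions sit at
// the end of the shape.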
TEST_F(CudnnSimplifyPaddingTest, PaddedAndReshapedWeightsNHWC) {
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    weights_p = pad(s8[3,3,64,60] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
    weights = s8[3,3,2,32,64] reshape(weights_p)
    conv = (s8[10,10,10,2,32], u8[0]) custom-call(
      s8[10,10,10,2,32] parameter(1),
      weights
    ), window={size=3x3}, dim_labels=b01f?_01i?o->b01f?,
      custom_call_target="__cudnn$convForward"
    conv_result = get-tuple-element(conv), index=0
    slice = s8[10,10,10,60] slice(s8[10,10,10,64] reshape(conv_result)), slice={[0:10], [0:10], [0:10], [0:60]}
    ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
  }
  )")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_TRUE(changed);

  SCOPED_TRACE(module->ToString());
  auto* root = module->entry_computation()->root_instruction();
  const HloInstruction* pad = nullptr;
  ASSERT_THAT(
      root, GmockMatch(
                m::Pad(&pad, m::Reshape(m::GetTupleElement(m::CustomCall(), 0)),
                       m::ConstantScalar(0))));

  ExpectOnlyPadsOneDim(/*dim=*/3, /*padding_high=*/1, pad->padding_config());
}

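// The conv result is transposed before being reshaped and sliced; the pass has
// to trace the known-zero output channels through the transpose as well.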
TEST_F(CudnnSimplifyPaddingTest, PaddedTransposedAndReshapedOutput) {
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
    weights = s8[2,32,64,3,3] reshape(weights_p)
    conv = (s8[10,2,10,10,32], u8[0]) custom-call(
      s8[10,2,10,10,32] parameter(1),
      weights
    ), window={size=3x3}, dim_labels=bf01?_i?o01->bf01?,
      custom_call_target="__cudnn$convForward"
    conv_result = get-tuple-element(conv), index=0
    conv_transposed = s8[10,2,32,10,10] transpose(conv_result), dimensions={0,1,4,2,3}
    slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_transposed)), slice={[0:10], [0:60], [0:10], [0:10]}
    ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_6x0_0x0_0
  }
  )")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_TRUE(changed);

  SCOPED_TRACE(module->ToString());
  auto* root = module->entry_computation()->root_instruction();
  const HloInstruction* pad = nullptr;
  ASSERT_THAT(
      root,
      GmockMatch(m::Pad(
          &pad,
          m::Reshape(m::Transpose(m::GetTupleElement(m::CustomCall(), 0))),
          m::ConstantScalar(0))));

  ExpectOnlyPadsOneDim(/*dim=*/1, /*padding_high=*/2, pad->padding_config());
}

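// Here the weights are a constant rather than a pad instruction, so the pass
// has to inspect the constant's values to prove that the trailing output
// features of the conv result are 0.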
TEST_F(CudnnSimplifyPaddingTest, PaddedConstantWeight) {
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    conv = (s8[10,10,10,10], u8[0]) custom-call(
      s8[10,10,10,10] parameter(0),
      s8[3,3,10,10] constant({...})
    ), window={size=3x3}, dim_labels=b01f_01io->b01f,
      custom_call_target="__cudnn$convForward"
    conv_result = get-tuple-element(conv), index=0
    slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
    ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
  }
  )")
                    .ValueOrDie();

  // Set the constant's value. (The HLO text above sets it to all 0s.)
  {
    HloInstruction* weights = nullptr;
    ASSERT_THAT(module->entry_computation()->root_instruction(),
                GmockMatch(m::Pad(m::Slice(m::GetTupleElement(m::CustomCall(
                               m::Op(), m::Constant(&weights)))),
                           m::Op())));
    SetConstantValue<int8_t>(
        weights, [](absl::Span<const int64_t> dims, int8_t old_val) -> int8_t {
          if (dims[3] < 6) return 1;
          return 0;
        });
  }

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_TRUE(changed);

  SCOPED_TRACE(module->ToString());
  auto* root = module->entry_computation()->root_instruction();
  const HloInstruction* pad = nullptr;
  ASSERT_THAT(root,
              GmockMatch(m::Pad(&pad, m::GetTupleElement(m::CustomCall(), 0),
                                m::ConstantScalar(0))));

  ExpectOnlyPadsOneDim(/*dim=*/3, /*padding_high=*/1, pad->padding_config());
}

TEST_F(CudnnSimplifyPaddingTest, PaddedConstantWeightIsNotLargeEnough) {
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    conv = (s8[10,10,10,10], u8[0]) custom-call(
      s8[10,10,10,10] parameter(0),
      s8[3,3,10,10] constant({...})
    ), window={size=3x3}, dim_labels=b01f_01io->b01f,
      custom_call_target="__cudnn$convForward"
    conv_result = get-tuple-element(conv), index=0
    slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
    ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
  }
  )")
                    .ValueOrDie();

  // Set the constant's value. (The HLO text above sets it to all 0s.)
  {
    HloInstruction* weights = nullptr;
    ASSERT_THAT(module->entry_computation()->root_instruction(),
                GmockMatch(m::Pad(m::Slice(m::GetTupleElement(m::CustomCall(
                               m::Op(), m::Constant(&weights)))),
                           m::Op())));
    SetConstantValue<int8_t>(
        weights, [](absl::Span<const int64_t> dims, int8_t old_val) -> int8_t {
          // Only output features 0-4 are 0; features 5 and up, including the
          // ones the slice removes, are nonzero.
          if (dims[3] < 5 /*|| (dims[3] == 5 && dims[2] > 1)*/) return 0;
          return 1;
        });
  }

  // Some of the values sliced off are not 0, so we can't merge the slice into
  // the pad.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_FALSE(changed);
}

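// The reshape of the conv result mixes the feature dimension with non-adjacent
// dimensions rather than just folding the trailing vect_c dimension back in,
// so the pass can't reason about the sliced-off channels and must not fire.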
TEST_F(CudnnSimplifyPaddingTest, ReshapeDoesntMergeVectCDim) {
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
    weights = s8[2,64,3,3,32] reshape(weights_p)
    conv = (s8[10,2,10,10,32], u8[0]) custom-call(
      s8[10,2,10,10,32] parameter(1),
      weights_p
    ), window={size=3x3}, dim_labels=bf01?_io01?->bf01?,
      custom_call_target="__cudnn$convForward"
    conv_result = get-tuple-element(conv), index=0
    slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_result)), slice={[0:10], [0:60], [0:10], [0:10]}
    ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_6x0_0x0_0
  }
  )")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_FALSE(changed);
}

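// The conv output's dim_labels contain two vect_c ('?') dimensions; the pass
// only understands a single one and should leave the graph alone.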
TEST_F(CudnnSimplifyPaddingTest, TwoVectCDimsInOutput) {
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
    weights = s8[2,64,3,3,32] reshape(weights_p)
    conv = (s8[10,2,10,10,4,8], u8[0]) custom-call(
      s8[10,2,10,10,32] parameter(1),
      weights
    ), window={size=3x3}, dim_labels=bf01?_io01?->bf01??,
      custom_call_target="__cudnn$convForward"
    conv_result = get-tuple-element(conv), index=0
    conv_transposed = s8[10,2,4,8,10,10] transpose(conv_result), dimensions={0,1,4,5,2,3}
    slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_transposed)), slice={[0:10], [0:60], [0:10], [0:10]}
    ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_6x0_0x0_0
  }
  )")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_FALSE(changed);
}

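// Same, but the two vect_c dimensions are on the kernel rather than on the
// conv output.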
TEST_F(CudnnSimplifyPaddingTest, TwoVectCDimsInKernel) {
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    weights_p = pad(s8[64,60,3,3] parameter(0), s8[] constant(0)), padding=0_0x0_4x0_0x0_0
    weights = s8[2,64,3,3,4,8] reshape(weights_p)
    conv = (s8[10,2,10,10,32], u8[0]) custom-call(
      s8[10,2,10,10,32] parameter(1),
      weights
    ), window={size=3x3}, dim_labels=bf01?_io01??->bf01?,
      custom_call_target="__cudnn$convForward"
    conv_result = get-tuple-element(conv), index=0
    conv_transposed = s8[10,2,32,10,10] transpose(conv_result), dimensions={0,1,4,2,3}
    slice = s8[10,60,10,10] slice(s8[10,64,10,10] reshape(conv_transposed)), slice={[0:10], [0:60], [0:10], [0:10]}
    ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_6x0_0x0_0
  }
  )")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_FALSE(changed);
}

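// The slice starts at offset 1 in a spatial dimension, so it removes real
// data, not just trailing known-zero channels, and can't be merged away.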
TEST_F(CudnnSimplifyPaddingTest, SliceDoesntStartAtBeginning) {
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
    conv = (s8[10,10,10,10], u8[0]) custom-call(
      s8[10,10,10,10] parameter(1),
      weights
    ), window={size=3x3}, dim_labels=b01f_01io->b01f,
      custom_call_target="__cudnn$convForward"
    conv_result = get-tuple-element(conv), index=0
    slice = s8[10,9,10,6] slice(conv_result), slice={[0:10], [1:10], [0:10], [0:6]}
    ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
  }
  )")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_FALSE(changed);
}

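// Same, but now it's the feature dimension whose slice doesn't start at 0, so
// the removed leading element isn't a trailing known-zero channel.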
TEST_F(CudnnSimplifyPaddingTest, SliceDoesntStartAtBeginningOfFeatureDim) {
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
    conv = (s8[10,10,10,10], u8[0]) custom-call(
      s8[10,10,10,10] parameter(1),
      weights
    ), window={size=3x3}, dim_labels=b01f_01io->b01f,
      custom_call_target="__cudnn$convForward"
    conv_result = get-tuple-element(conv), index=0
    slice = s8[10,10,10,5] slice(conv_result), slice={[0:10], [0:10], [0:10], [1:6]}
    ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
  }
  )")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_FALSE(changed);
}

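// A strided slice keeps every other feature rather than just dropping trailing
// known-zero channels, so it can't be merged into the pad.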
TEST_F(CudnnSimplifyPaddingTest, SliceHasStride) {
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
    conv = (s8[10,10,10,10], u8[0]) custom-call(
      s8[10,10,10,10] parameter(1),
      weights
    ), window={size=3x3}, dim_labels=b01f_01io->b01f,
      custom_call_target="__cudnn$convForward"
    conv_result = get-tuple-element(conv), index=0
    slice = s8[10,10,10,3] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6:2]}
    ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5
  }
  )")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_FALSE(changed);
}

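// The trailing pad adds interior padding, which can't be expressed by simply
// padding the conv result, so the pass must not fire.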
TEST_F(CudnnSimplifyPaddingTest, PadAddsInteriorPadding) {
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
    conv = (s8[10,10,10,10], u8[0]) custom-call(
      s8[10,10,10,10] parameter(1),
      weights
    ), window={size=3x3}, dim_labels=b01f_01io->b01f,
      custom_call_target="__cudnn$convForward"
    conv_result = get-tuple-element(conv), index=0
    slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
    ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_5_1
  }
  )")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_FALSE(changed);
}

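// The slice removes more known-zero channels (4) than the pad adds back (2),
// so the merged op is a net negative pad, which algsimp then rewrites as a
// slice of the conv result.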
TEST_F(CudnnSimplifyPaddingTest, SliceMoreElementsThanPad) {
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    weights = pad(s8[3,3,10,10] parameter(0), s8[] constant(0)), padding=0_0x0_0x0_0x0_4
    conv = (s8[10,10,10,10], u8[0]) custom-call(
      s8[10,10,10,10] parameter(1),
      weights
    ), window={size=3x3}, dim_labels=b01f_01io->b01f,
      custom_call_target="__cudnn$convForward"
    conv_result = get-tuple-element(conv), index=0
    slice = s8[10,10,10,6] slice(conv_result), slice={[0:10], [0:10], [0:10], [0:6]}
    ROOT pad = pad(slice, s8[] constant(0)), padding=0_0x0_0x0_0x0_2
  }
  )")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunJustThisPass(module.get()));
  EXPECT_TRUE(changed);

  SCOPED_TRACE(module->ToString());
  auto* root = module->entry_computation()->root_instruction();
  const HloInstruction* slice = nullptr;
  // The pass creates a pad with negative padding; this is simplified by
  // algsimp into a slice.
  ASSERT_THAT(root, GmockMatch(m::Slice(
                        &slice, m::GetTupleElement(m::CustomCall(), 0))));
  for (int64_t i = 0; i < slice->shape().dimensions_size(); ++i) {
    SCOPED_TRACE(i);
    EXPECT_EQ(slice->slice_starts(i), 0);
    EXPECT_EQ(slice->slice_strides(i), 1);
    if (i != 3) {
      EXPECT_EQ(slice->slice_limits(i), 10);
    } else {
      EXPECT_EQ(slice->slice_limits(i), 8);
    }
  }
}

}  // anonymous namespace
}  // namespace xla::gpu