/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h"

#include <optional>
#include <string>

#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/platform/test.h"

namespace xla {
namespace gpu {
namespace {

namespace op = xla::testing::opcode_matchers;
using ::testing::_;

class GpuConvRewriterTest : public HloTestBase {
 public:
  GpuConvRewriterTest()
      : HloTestBase(/*layout_sensitive=*/true,
                    /*allow_mixed_precision=*/false) {
    for (int i = 0; i < 2; ++i) {
      WindowDimension* window_dim = default_conv_window_.add_dimensions();
      window_dim->set_size(1);
      window_dim->set_stride(1);
      window_dim->set_padding_low(0);
      window_dim->set_padding_high(0);
      window_dim->set_window_dilation(1);
      window_dim->set_base_dilation(1);
    }
    // TF data shapes are NHWC by default, and filter shapes are HWIO by
    // default. For backward filter convolution, we need to swap the batch and
    // feature dimensions in the activations, and treat the batch dimension in
    // gradients as the input feature dimension in the filter.
    //
    // TODO(jingyue): Add more tests on NCHW input order, which TF also
    // supports.
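    // For example, NHWC activations f32[20, 35, 35, 32] enter the backward
    // filter convolution with dimension 3 (C=32) acting as the batch
    // dimension and dimension 0 (N=20) acting as the feature dimension, per
    // the mapping below.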
    tf_default_dnums_for_backward_filter_.set_input_batch_dimension(3);
    tf_default_dnums_for_backward_filter_.set_input_feature_dimension(0);
    tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(1);
    tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(2);
    tf_default_dnums_for_backward_filter_.set_kernel_input_feature_dimension(0);
    tf_default_dnums_for_backward_filter_.set_kernel_output_feature_dimension(
        3);
    tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(1);
    tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(2);
    tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(0);
    tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(1);
    tf_default_dnums_for_backward_filter_.set_output_batch_dimension(2);
    tf_default_dnums_for_backward_filter_.set_output_feature_dimension(3);

    tf_default_dnums_for_backward_input_.set_input_batch_dimension(0);
    tf_default_dnums_for_backward_input_.set_output_batch_dimension(0);
    tf_default_dnums_for_backward_input_.set_input_feature_dimension(3);
    tf_default_dnums_for_backward_input_.set_output_feature_dimension(3);
    tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(1);
    tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(1);
    tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(2);
    tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(2);
    tf_default_dnums_for_backward_input_.set_kernel_input_feature_dimension(3);
    tf_default_dnums_for_backward_input_.set_kernel_output_feature_dimension(2);
    tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(0);
    tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(1);
  }

 protected:
  bool RunPass(HloModule* module) {
    return GpuConvRewriter().Run(module).value();
  }

  // A convolution window with stride 1 and zero padding. Each dimension's
  // size is initialized to 1; tests override sizes as needed.
  Window default_conv_window_;
  ConvolutionDimensionNumbers tf_default_dnums_for_backward_filter_;
  ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_;
};

TEST_F(GpuConvRewriterTest, BackwardFilterConvolve) {
  HloComputation::Builder builder(TestName());
  HloInstruction* activations =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "activations"));
  HloInstruction* gradients =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeShape(F32, {1, 1, 2, 1}), "gradients"));
  Window conv_window = default_conv_window_;
  conv_window.mutable_dimensions(1)->set_size(2);
  conv_window.mutable_dimensions(1)->set_window_dilation(2);
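  // The window dilation of 2 set above encodes the forward convolution's
  // stride; together with the gradients acting as the rhs "filter", this is
  // what the rewriter recognizes as a backward-filter pattern.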
  auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeInference::InferConvolveShape(
          activations->shape(), gradients->shape(), /*feature_group_count=*/1,
          /*batch_group_count=*/1, conv_window,
          tf_default_dnums_for_backward_filter_,
          /*preferred_element_type=*/std::nullopt)
          .value(),
      activations, gradients, /*feature_group_count=*/1,
      /*batch_group_count=*/1, conv_window,
      tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));

  OpMetadata metadata;
  metadata.set_op_name("foo");
  conv->set_metadata(metadata);

  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());
  EXPECT_TRUE(RunPass(module.get()));
  ASSERT_THAT(entry_computation->root_instruction(),
              op::GetTupleElement(
                  op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));

  // Check that metadata was preserved.
  const auto& md_after_opt =
      entry_computation->root_instruction()->operand(0)->metadata();
  EXPECT_TRUE(protobuf_util::ProtobufEquals(md_after_opt, metadata))
      << md_after_opt.DebugString() << " vs " << metadata.DebugString();
}

TEST_F(GpuConvRewriterTest,
       BackwardFilterConvolveEquivalentToForwardConvolution) {
  HloComputation::Builder builder(TestName());
  HloInstruction* activations =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "activations"));
  HloInstruction* gradients =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "gradients"));
  Window conv_window = default_conv_window_;
  conv_window.mutable_dimensions(1)->set_size(3);
  builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeInference::InferConvolveShape(
          activations->shape(), gradients->shape(), /*feature_group_count=*/1,
          /*batch_group_count=*/1, conv_window,
          tf_default_dnums_for_backward_filter_,
          /*preferred_element_type=*/std::nullopt)
          .value(),
      activations, gradients, /*feature_group_count=*/1,
      /*batch_group_count=*/1, conv_window,
      tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());
  EXPECT_TRUE(RunPass(module.get()));
  EXPECT_THAT(
      entry_computation->root_instruction(),
      op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
}

// Extracted from block35 training.
TEST_F(GpuConvRewriterTest, BackwardFilterConvolveWithPaddedActivations) {
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* activations =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "activations"));
  HloInstruction* gradients =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "gradients"));

  Window conv_window = default_conv_window_;
  for (int i = 0; i < 2; ++i) {
    conv_window.mutable_dimensions(i)->set_size(35);
    conv_window.mutable_dimensions(i)->set_padding_low(1);
    conv_window.mutable_dimensions(i)->set_padding_high(1);
  }
  builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients,
      /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
      tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());
  EXPECT_TRUE(RunPass(module.get()));
  EXPECT_THAT(entry_computation->root_instruction(),
              op::GetTupleElement(
                  op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
}

// Extracted from inception v3 training.
TEST_F(GpuConvRewriterTest, BackwardFilterConvolveWithPaddedGradients) {
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* activations =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), "activations"));
  HloInstruction* gradients =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "gradients"));

  Window conv_window = default_conv_window_;
  for (int i = 0; i < 2; ++i) {
    conv_window.mutable_dimensions(i)->set_size(4);
    conv_window.mutable_dimensions(i)->set_padding_high(-1);
    conv_window.mutable_dimensions(i)->set_window_dilation(2);
  }
  builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients,
      /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
      tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());
  EXPECT_TRUE(RunPass(module.get()));
  EXPECT_THAT(entry_computation->root_instruction(),
              op::GetTupleElement(
                  op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
}

TEST_F(GpuConvRewriterTest, BackwardFilterConvolveWithUnevenPadding) {
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* activations =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "activations"));
  HloInstruction* gradients =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "gradients"));

  Window conv_window = default_conv_window_;
  for (int i = 0; i < 2; ++i) {
    conv_window.mutable_dimensions(i)->set_size(35);
    // Uneven padding: padding_low=0, padding_high=1.
    conv_window.mutable_dimensions(i)->set_padding_high(1);
  }
  builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients,
      /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
      tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());
  EXPECT_TRUE(RunPass(module.get()));
  EXPECT_THAT(entry_computation->root_instruction(),
              op::GetTupleElement(
                  op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
}

TEST_F(GpuConvRewriterTest, BackwardInputConvolveEvenPadding) {
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* output =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {4, 5, 16, 16}), "output"));
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeShape(F32, {5, 3, 7, 7}), "kernel"));
  HloInstruction* reverse_kernel = builder.AddInstruction(
      HloInstruction::CreateReverse(kernel->shape(), kernel, {2, 3}));

  Window conv_window = default_conv_window_;
  for (int i = 0; i < 2; ++i) {
    conv_window.mutable_dimensions(i)->set_size(7);
    conv_window.mutable_dimensions(i)->set_padding_low(3);
    conv_window.mutable_dimensions(i)->set_padding_high(3);
  }
  ConvolutionDimensionNumbers conv_dnums;
  conv_dnums.set_input_batch_dimension(0);
  conv_dnums.set_output_batch_dimension(0);
  conv_dnums.set_input_feature_dimension(1);
  conv_dnums.set_output_feature_dimension(1);
  conv_dnums.add_input_spatial_dimensions(2);
  conv_dnums.add_output_spatial_dimensions(2);
  conv_dnums.add_input_spatial_dimensions(3);
  conv_dnums.add_output_spatial_dimensions(3);
  conv_dnums.set_kernel_input_feature_dimension(0);
  conv_dnums.set_kernel_output_feature_dimension(1);
  conv_dnums.add_kernel_spatial_dimensions(2);
  conv_dnums.add_kernel_spatial_dimensions(3);

  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeUtil::MakeShape(F32, {4, 3, 16, 16}), /*lhs=*/output,
      /*rhs=*/reverse_kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, conv_window, conv_dnums,
      DefaultPrecisionConfig(2)));
  // Verify the convolution's shape is consistent with ShapeInference.
  CHECK(ShapeUtil::Compatible(
      conv->shape(),
      ShapeInference::InferConvolveShape(
          output->shape(), reverse_kernel->shape(),
          /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
          conv_dnums, /*preferred_element_type=*/std::nullopt)
          .value()));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());
  EXPECT_TRUE(RunPass(module.get()));

  ASSERT_THAT(entry_computation->root_instruction(),
              op::GetTupleElement(
                  op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
  const HloInstruction* custom_call =
      entry_computation->root_instruction()->operand(0);
  for (int i = 0; i < 2; ++i) {
    const WindowDimension& window_dim = custom_call->window().dimensions(i);
    // Low padding of the backward input convolution
    //   = kernel_size - 1 - low padding on gradients.
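    // Here: kernel_size = 7 and low padding on gradients = 3, so the
    // rewritten convolution keeps padding_low = 7 - 1 - 3 = 3.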
    EXPECT_EQ(3, window_dim.padding_low());
    EXPECT_EQ(3, window_dim.padding_high());
    EXPECT_EQ(1, window_dim.stride());
    EXPECT_EQ(1, window_dim.base_dilation());
  }
}

// Convolve([abc], [x], base_dilation=2)
//   = Convolve([abc], Reverse([x]), base_dilation=2)
//   = BackwardInputConvolve([abc], [x], stride=2)
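// (Reversing a 1x1 filter is a no-op, and the base dilation of 2 on the
// gradients corresponds to a stride of 2 in the forward convolution whose
// input gradient is being computed.)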
TEST_F(GpuConvRewriterTest, BackwardInputConvolve1x1Filter) {
  auto builder = HloComputation::Builder(TestName());
  // NHWC dimension order.
  HloInstruction* output =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output"));
  // HWOI dimension order.
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel"));

  Window conv_window = default_conv_window_;
  conv_window.mutable_dimensions(1)->set_base_dilation(2);

  builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeInference::InferConvolveShape(
          output->shape(), kernel->shape(),
          /*feature_group_count=*/1,
          /*batch_group_count=*/1, conv_window,
          tf_default_dnums_for_backward_input_,
          /*preferred_element_type=*/std::nullopt)
          .value(),
      /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, conv_window,
      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());
  EXPECT_TRUE(RunPass(module.get()));
  EXPECT_THAT(entry_computation->root_instruction(),
              op::GetTupleElement(
                  op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
}

// BackwardInputConvolve([abc], [x], stride=1) is equivalent to
// ForwardConvolve([abc], [x], stride=1). No need to fold it into backward
// input convolution.
TEST_F(GpuConvRewriterTest,
       BackwardInputConvolve1x1FilterEquivalentToForwardConvolve) {
  auto builder = HloComputation::Builder(TestName());
  // NHWC dimension order.
  HloInstruction* output =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output"));
  // HWOI dimension order.
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel"));

  builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeInference::InferConvolveShape(
          output->shape(), kernel->shape(), /*feature_group_count=*/1,
          /*batch_group_count=*/1, default_conv_window_,
          tf_default_dnums_for_backward_input_,
          /*preferred_element_type=*/std::nullopt)
          .value(),
      /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, default_conv_window_,
      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());
  EXPECT_TRUE(RunPass(module.get()));
  EXPECT_THAT(
      entry_computation->root_instruction(),
      op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
}

// Extracted from Inception V3 training.
//
//                                 filter(HWIO)
//                                 3x3x192x320
//                                      |
//                                      v
//      gradients(NHWC)              reverse
//        20x4x4x320               3x3x192x320
//                    \            /
//                     \          /
//   conv (NHWC) with padding (low=2,high=3,interior=1)
//                  20x10x10x192
//
// Gradients are padded unevenly.
TEST_F(GpuConvRewriterTest, BackwardInputConvolveUnevenPaddingOnGradients) {
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* output =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "output"));
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), "kernel"));
  HloInstruction* reverse_kernel = builder.AddInstruction(
      HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));

  Window conv_window = default_conv_window_;
  for (int i = 0; i < 2; ++i) {
    conv_window.mutable_dimensions(i)->set_size(3);
    conv_window.mutable_dimensions(i)->set_padding_low(2);
    conv_window.mutable_dimensions(i)->set_padding_high(3);
    // Interior padding = 1.
    conv_window.mutable_dimensions(i)->set_base_dilation(2);
  }
  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
      /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
  // Verify the convolution's shape is consistent with ShapeInference.
  CHECK(ShapeUtil::Compatible(
      conv->shape(), ShapeInference::InferConvolveShape(
                         output->shape(), reverse_kernel->shape(),
                         /*feature_group_count=*/1, /*batch_group_count=*/1,
                         conv_window, tf_default_dnums_for_backward_input_,
                         /*preferred_element_type=*/std::nullopt)
                         .value()));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());
  EXPECT_TRUE(RunPass(module.get()));
  ASSERT_THAT(entry_computation->root_instruction(),
              op::GetTupleElement(
                  op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
  const HloInstruction* custom_call =
      entry_computation->root_instruction()->operand(0);
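  // The forward conv's base_dilation of 2 becomes the stride of the backward
  // input convolution, and the rewritten low padding is
  // kernel_size - 1 - forward padding_low = 3 - 1 - 2 = 0.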
  for (int i = 0; i < 2; ++i) {
    const WindowDimension& window_dim = custom_call->window().dimensions(i);
    EXPECT_EQ(0, window_dim.padding_low());
    EXPECT_EQ(0, window_dim.padding_high());
    EXPECT_EQ(2, window_dim.stride());
    EXPECT_EQ(1, window_dim.base_dilation());
  }
}

// Similar to BackwardInputConvolveUnevenPaddingOnGradients, but the low
// padding of the gradients exceeds kernel_size - 1. Therefore, this pattern
// cannot be fused.
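// Concretely: kernel_size - 1 = 2 here while padding_low = 3, so the
// rewritten low padding would be negative (2 - 3 = -1).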
TEST_F(GpuConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) {
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* output =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "output"));
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), "kernel"));
  HloInstruction* reverse_kernel = builder.AddInstruction(
      HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));

  Window conv_window = default_conv_window_;
  for (int i = 0; i < 2; ++i) {
    conv_window.mutable_dimensions(i)->set_size(3);
    conv_window.mutable_dimensions(i)->set_padding_low(3);
    conv_window.mutable_dimensions(i)->set_padding_high(2);
    conv_window.mutable_dimensions(i)->set_base_dilation(2);
  }
  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
      /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
  // Verify the convolution's shape is consistent with ShapeInference.
  CHECK(ShapeUtil::Compatible(
      conv->shape(), ShapeInference::InferConvolveShape(
                         output->shape(), reverse_kernel->shape(),
                         /*feature_group_count=*/1, /*batch_group_count=*/1,
                         conv_window, tf_default_dnums_for_backward_input_,
                         /*preferred_element_type=*/std::nullopt)
                         .value()));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());
  EXPECT_TRUE(RunPass(module.get()));
  EXPECT_THAT(
      entry_computation->root_instruction(),
      op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
}

// Extracted from Resnet-50.
//
// For simplicity, we focus on the column dimension and ignore other
// dimensions. We use [?] to represent the shape instead of the content.
//
// Suppose operator FC does
//   [7] = conv([14], [3], stride=2, padding_high=1)  // Padding::kSame
//
// BC = BackwardInput(FC) does:
//   [14] = conv([7], reverse([3]),
//               padding_low=2, padding_high=1, base_dilation=2)
//
// We should fuse BC even though padding on activations is uneven, because
// GpuConvPaddingLegalization will canonicalize the fusion HLO.
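// With kernel_size = 3, the fused backward conv should get padding_low =
// 3 - 1 - 2 = 0 and padding_high = 3 - 1 - 1 = 1, which is what the checks
// at the end of this test expect.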
TEST_F(GpuConvRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) {
  auto builder = HloComputation::Builder(TestName());
  // The gradients are in NHWC layout.
  HloInstruction* output =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {1, 1, 7, 1}), "output"));
  // The kernel is in HWOI layout.
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeShape(F32, {1, 3, 1, 1}), "kernel"));
  HloInstruction* reverse_kernel = builder.AddInstruction(
      HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));

  Window conv_window = default_conv_window_;
  WindowDimension* forward_conv_col_dim = conv_window.mutable_dimensions(1);
  forward_conv_col_dim->set_size(3);
  forward_conv_col_dim->set_padding_low(2);
  forward_conv_col_dim->set_padding_high(1);
  forward_conv_col_dim->set_base_dilation(2);
  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeUtil::MakeShape(F32, {1, 1, 14, 1}), output, reverse_kernel,
      /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
  // Verify the convolution's shape is consistent with ShapeInference.
  CHECK(ShapeUtil::Compatible(
      conv->shape(), ShapeInference::InferConvolveShape(
                         output->shape(), reverse_kernel->shape(),
                         /*feature_group_count=*/1, /*batch_group_count=*/1,
                         conv_window, tf_default_dnums_for_backward_input_,
                         /*preferred_element_type=*/std::nullopt)
                         .value()));

  auto module = CreateNewVerifiedModule();
  const HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());
  EXPECT_TRUE(RunPass(module.get()));
  ASSERT_THAT(entry_computation->root_instruction(),
              op::GetTupleElement(
                  op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
  const WindowDimension& backward_conv_col_dim =
      entry_computation->root_instruction()->operand(0)->window().dimensions(1);
  EXPECT_EQ(0, backward_conv_col_dim.padding_low());
  EXPECT_EQ(1, backward_conv_col_dim.padding_high());
}

// For simplicity, we focus on the column dimension and ignore other
// dimensions. We use [?] to represent the shape instead of the content.
//
// Suppose operator FC does
//   [3] = conv([4], [2], padding_low=1, padding_high=-1)
//
// BC = BackwardInput(FC) does:
//   [4] = conv([3], reverse([2]), padding_high=2)
//
// We currently don't fuse BC because GpuConvPaddingLegalization
// doesn't support negative padding on the gradients of backward convolution
// (b/32744257).
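// Here kernel_size = 2 and BC's padding_high = 2, so matching would recover
// a forward padding_high of kernel_size - 1 - 2 = -1; that negative padding
// is why the rewriter leaves this conv as a forward custom call.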
TEST_F(GpuConvRewriterTest,
       BackwardInputConvolveNegativePaddingHighOnActivations) {
  auto builder = HloComputation::Builder(TestName());
  // The gradients are in NHWC layout.
  HloInstruction* output =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output"));
  // The kernel is in HWOI layout.
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeShape(F32, {1, 2, 1, 1}), "kernel"));
  HloInstruction* reverse_kernel = builder.AddInstruction(
      HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));

  Window conv_window = default_conv_window_;
  WindowDimension* forward_conv_col_dim = conv_window.mutable_dimensions(1);
  forward_conv_col_dim->set_size(2);
  forward_conv_col_dim->set_padding_high(2);
  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeUtil::MakeShape(F32, {1, 1, 4, 1}), output, reverse_kernel,
      /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
  // Verify the convolution's shape is consistent with ShapeInference.
  CHECK(ShapeUtil::Compatible(
      conv->shape(), ShapeInference::InferConvolveShape(
                         output->shape(), reverse_kernel->shape(),
                         /*feature_group_count=*/1, /*batch_group_count=*/1,
                         conv_window, tf_default_dnums_for_backward_input_,
                         /*preferred_element_type=*/std::nullopt)
                         .value()));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());
  EXPECT_TRUE(RunPass(module.get()));
  EXPECT_THAT(
      entry_computation->root_instruction(),
      op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
}

// Check that we will materialize a reversed version of a constant in order to
// pattern-match a backwards input convolution.
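// In the HLO below, lhs_dilate=2x2 encodes the forward convolution's stride
// of 2, and the kernel's dim_labels (01oi) carry the swapped input/output
// feature dimensions that the backward-input pattern expects.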
TEST_F(GpuConvRewriterTest, BackwardInputConvolveConstantFilter) {
  Array4D<float> constant_arr(4, 4, 2, 2);
  constant_arr.FillIota(0);
  std::string constant_str =
      LiteralUtil::CreateR4FromArray4D(constant_arr).ToStringWithoutShape();

  const std::string module_str = absl::StrFormat(R"(
    HloModule test

    ENTRY entry_computation {
      param0 = f32[128,2,16,16]{3,2,1,0} parameter(0)
      constant = f32[4,4,2,2]{3,2,1,0} constant(%s)
      ROOT convolution = f32[128,2,32,32]{3,2,1,0} convolution(param0, constant),
          window={size=4x4 pad=2_2x2_2 lhs_dilate=2x2},
          dim_labels=bf01_01oi->bf01, feature_group_count=1
    })",
                                                 constant_str);
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  EXPECT_TRUE(RunPass(m.get()));
  EXPECT_THAT(
      m->entry_computation()->root_instruction(),
      op::GetTupleElement(op::CustomCall(kCudnnConvBackwardInputCallTarget, _,
                                         op::Reverse(op::Constant())),
                          0));
}

TEST_F(GpuConvRewriterTest, TestBackwardFilterPattern) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      input = f32[8,120,256,256] parameter(0)
      filter = f32[8,120,256,256] parameter(1)

      ROOT conv = f32[120,120,3,3] convolution(input, filter), window={size=256x256 pad=1_1x1_1}, dim_labels=fb01_io01->fb01
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  EXPECT_TRUE(RunPass(m.get()));
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              op::GetTupleElement(
                  op::CustomCall(kCudnnConvBackwardFilterCallTarget, _, _), 0));
}

}  // anonymous namespace
}  // namespace gpu
}  // namespace xla