• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h"
17 
18 #include "tensorflow/compiler/xla/protobuf_util.h"
19 #include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
21 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
22 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
23 #include "tensorflow/compiler/xla/service/hlo_module.h"
24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
25 #include "tensorflow/compiler/xla/service/shape_inference.h"
26 #include "tensorflow/compiler/xla/test.h"
27 #include "tensorflow/compiler/xla/test_helpers.h"
28 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
29 #include "tensorflow/core/platform/test.h"
30 
31 namespace xla {
32 namespace gpu {
33 namespace {
34 
35 namespace op = xla::testing::opcode_matchers;
36 using ::testing::_;
37 
38 class GpuConvRewriterTest : public HloTestBase {
39  public:
GpuConvRewriterTest()40   GpuConvRewriterTest()
41       : HloTestBase(/*layout_sensitive=*/true,
42                     /*allow_mixed_precision=*/false) {
43     for (int i = 0; i < 2; ++i) {
44       WindowDimension* window_dim = default_conv_window_.add_dimensions();
45       window_dim->set_size(1);
46       window_dim->set_stride(1);
47       window_dim->set_padding_low(0);
48       window_dim->set_padding_high(0);
49       window_dim->set_window_dilation(1);
50       window_dim->set_base_dilation(1);
51     }
52     // TF data shapes are by default in the NHWC order, and filter shape is by
53     // default in HWIO order. For backward filter convolution, we need to swap
54     // the batch and feature dimension in the activations, and treat the batch
55     // dimension in gradients as the input feature dimension in the filter.
56     //
57     // TODO(jingyue): Add more tests on NCHW input order, which TF also
58     // supports.
59     tf_default_dnums_for_backward_filter_.set_input_batch_dimension(3);
60     tf_default_dnums_for_backward_filter_.set_input_feature_dimension(0);
61     tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(1);
62     tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(2);
63     tf_default_dnums_for_backward_filter_.set_kernel_input_feature_dimension(0);
64     tf_default_dnums_for_backward_filter_.set_kernel_output_feature_dimension(
65         3);
66     tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(1);
67     tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(2);
68     tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(0);
69     tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(1);
70     tf_default_dnums_for_backward_filter_.set_output_batch_dimension(2);
71     tf_default_dnums_for_backward_filter_.set_output_feature_dimension(3);
72 
73     tf_default_dnums_for_backward_input_.set_input_batch_dimension(0);
74     tf_default_dnums_for_backward_input_.set_output_batch_dimension(0);
75     tf_default_dnums_for_backward_input_.set_input_feature_dimension(3);
76     tf_default_dnums_for_backward_input_.set_output_feature_dimension(3);
77     tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(1);
78     tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(1);
79     tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(2);
80     tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(2);
81     tf_default_dnums_for_backward_input_.set_kernel_input_feature_dimension(3);
82     tf_default_dnums_for_backward_input_.set_kernel_output_feature_dimension(2);
83     tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(0);
84     tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(1);
85   }
86 
87  protected:
RunPass(HloModule * module)88   bool RunPass(HloModule* module) {
89     return GpuConvRewriter().Run(module).ValueOrDie();
90   }
91 
92   // A convolution window with stride 1 and zero padding. The size fields are
93   // not set.
94   Window default_conv_window_;
95   ConvolutionDimensionNumbers tf_default_dnums_for_backward_filter_;
96   ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_;
97 };
98 
// A convolution matching the backward-filter pattern must be rewritten into
// a cuDNN backward-filter custom call, and the op metadata must survive.
TEST_F(GpuConvRewriterTest, BackwardFilterConvolve) {
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* activations =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "activations"));
  HloInstruction* gradients =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeShape(F32, {1, 1, 2, 1}), "gradients"));

  // Size-2 window, dilated by 2, on the column dimension.
  Window window = default_conv_window_;
  WindowDimension* col_dim = window.mutable_dimensions(1);
  col_dim->set_size(2);
  col_dim->set_window_dilation(2);

  const Shape conv_shape =
      ShapeInference::InferConvolveShape(
          activations->shape(), gradients->shape(), /*feature_group_count=*/1,
          /*batch_group_count=*/1, window,
          tf_default_dnums_for_backward_filter_,
          /*preferred_element_type=*/std::nullopt)
          .value();
  auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      conv_shape, activations, gradients, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, tf_default_dnums_for_backward_filter_,
      DefaultPrecisionConfig(2)));

  OpMetadata metadata;
  metadata.set_op_name("foo");
  conv->set_metadata(metadata);

  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());
  EXPECT_TRUE(RunPass(module.get()));
  ASSERT_THAT(entry_computation->root_instruction(),
              op::GetTupleElement(
                  op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));

  // The rewrite must carry the original convolution's metadata over to the
  // custom call.
  const auto& md_after_opt =
      entry_computation->root_instruction()->operand(0)->metadata();
  EXPECT_TRUE(protobuf_util::ProtobufEquals(md_after_opt, metadata))
      << md_after_opt.DebugString() << " vs " << metadata.DebugString();
}
139 
// When the window covers the whole input, the "backward filter" pattern is
// indistinguishable from a forward convolution; expect a forward rewrite.
TEST_F(GpuConvRewriterTest,
       BackwardFilterConvolveEquivalentToForwardConvolution) {
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* activations =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "activations"));
  HloInstruction* gradients =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "gradients"));

  Window window = default_conv_window_;
  window.mutable_dimensions(1)->set_size(3);

  const Shape conv_shape =
      ShapeInference::InferConvolveShape(
          activations->shape(), gradients->shape(), /*feature_group_count=*/1,
          /*batch_group_count=*/1, window,
          tf_default_dnums_for_backward_filter_,
          /*preferred_element_type=*/std::nullopt)
          .value();
  builder.AddInstruction(HloInstruction::CreateConvolve(
      conv_shape, activations, gradients, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, tf_default_dnums_for_backward_filter_,
      DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());
  EXPECT_TRUE(RunPass(module.get()));
  EXPECT_THAT(
      entry_computation->root_instruction(),
      op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
}
170 
171 // Extracted from block35 training.
TEST_F(GpuConvRewriterTest, BackwardFilterConvolveWithPaddedActivations) {
  HloComputation::Builder builder(TestName());
  HloInstruction* activations =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "activations"));
  HloInstruction* gradients =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "gradients"));

  // 35x35 window with symmetric padding of 1 on both spatial dimensions.
  Window window = default_conv_window_;
  for (int dim = 0; dim < 2; ++dim) {
    WindowDimension* window_dim = window.mutable_dimensions(dim);
    window_dim->set_size(35);
    window_dim->set_padding_low(1);
    window_dim->set_padding_high(1);
  }
  builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients,
      /*feature_group_count=*/1, /*batch_group_count=*/1, window,
      tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());
  EXPECT_TRUE(RunPass(module.get()));
  EXPECT_THAT(entry_computation->root_instruction(),
              op::GetTupleElement(
                  op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
}
200 
201 // Extracted from inception v3 training.
TEST_F(GpuConvRewriterTest, BackwardFilterConvolveWithPaddedGradients) {
  HloComputation::Builder builder(TestName());
  HloInstruction* activations =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), "activations"));
  HloInstruction* gradients =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "gradients"));

  // Dilated 4x4 window with negative high padding on both spatial dims.
  Window window = default_conv_window_;
  for (int dim = 0; dim < 2; ++dim) {
    WindowDimension* window_dim = window.mutable_dimensions(dim);
    window_dim->set_size(4);
    window_dim->set_padding_high(-1);
    window_dim->set_window_dilation(2);
  }
  builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients,
      /*feature_group_count=*/1, /*batch_group_count=*/1, window,
      tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());
  EXPECT_TRUE(RunPass(module.get()));
  EXPECT_THAT(entry_computation->root_instruction(),
              op::GetTupleElement(
                  op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
}
230 
TEST_F(GpuConvRewriterTest, BackwardFilterConvolveWithUnevenPadding) {
  HloComputation::Builder builder(TestName());
  HloInstruction* activations =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "activations"));
  HloInstruction* gradients =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "gradients"));

  Window window = default_conv_window_;
  for (int dim = 0; dim < 2; ++dim) {
    WindowDimension* window_dim = window.mutable_dimensions(dim);
    window_dim->set_size(35);
    // Uneven padding: padding_low=0, padding_high=1.
    window_dim->set_padding_high(1);
  }
  builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients,
      /*feature_group_count=*/1, /*batch_group_count=*/1, window,
      tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());
  EXPECT_TRUE(RunPass(module.get()));
  EXPECT_THAT(entry_computation->root_instruction(),
              op::GetTupleElement(
                  op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
}
259 
TEST_F(GpuConvRewriterTest,BackwardInputConvolveEvenPadding)260 TEST_F(GpuConvRewriterTest, BackwardInputConvolveEvenPadding) {
261   auto builder = HloComputation::Builder(TestName());
262   HloInstruction* output =
263       builder.AddInstruction(HloInstruction::CreateParameter(
264           0, ShapeUtil::MakeShape(F32, {4, 5, 16, 16}), "output"));
265   HloInstruction* kernel =
266       builder.AddInstruction(HloInstruction::CreateParameter(
267           1, ShapeUtil::MakeShape(F32, {5, 3, 7, 7}), "kernel"));
268   HloInstruction* reverse_kernel = builder.AddInstruction(
269       HloInstruction::CreateReverse(kernel->shape(), kernel, {2, 3}));
270 
271   Window conv_window = default_conv_window_;
272   for (int i = 0; i < 2; ++i) {
273     conv_window.mutable_dimensions(i)->set_size(7);
274     conv_window.mutable_dimensions(i)->set_padding_low(3);
275     conv_window.mutable_dimensions(i)->set_padding_high(3);
276   }
277   ConvolutionDimensionNumbers conv_dnums;
278   conv_dnums.set_input_batch_dimension(0);
279   conv_dnums.set_output_batch_dimension(0);
280   conv_dnums.set_input_feature_dimension(1);
281   conv_dnums.set_output_feature_dimension(1);
282   conv_dnums.add_input_spatial_dimensions(2);
283   conv_dnums.add_output_spatial_dimensions(2);
284   conv_dnums.add_input_spatial_dimensions(3);
285   conv_dnums.add_output_spatial_dimensions(3);
286   conv_dnums.set_kernel_input_feature_dimension(0);
287   conv_dnums.set_kernel_output_feature_dimension(1);
288   conv_dnums.add_kernel_spatial_dimensions(2);
289   conv_dnums.add_kernel_spatial_dimensions(3);
290 
291   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
292       ShapeUtil::MakeShape(F32, {4, 3, 16, 16}), /*lhs=*/output,
293       /*rhs=*/reverse_kernel, /*feature_group_count=*/1,
294       /*batch_group_count=*/1, conv_window, conv_dnums,
295       DefaultPrecisionConfig(2)));
296   // Verify the convolution's shape is consistent with ShapeInference.
297   CHECK(ShapeUtil::Compatible(
298       conv->shape(),
299       ShapeInference::InferConvolveShape(
300           output->shape(), reverse_kernel->shape(),
301           /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
302           conv_dnums, /*preferred_element_type=*/std::nullopt)
303           .ValueOrDie()));
304 
305   auto module = CreateNewVerifiedModule();
306   HloComputation* entry_computation =
307       module->AddEntryComputation(builder.Build());
308   EXPECT_TRUE(RunPass(module.get()));
309 
310   ASSERT_THAT(entry_computation->root_instruction(),
311               op::GetTupleElement(
312                   op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
313   const HloInstruction* custom_call =
314       entry_computation->root_instruction()->operand(0);
315   for (int i = 0; i < 2; ++i) {
316     const WindowDimension& window_dim = custom_call->window().dimensions(i);
317     // Low padding of the backward input convolution
318     //   = kernel_size - 1 - low padding on gradients.
319     EXPECT_EQ(3, window_dim.padding_low());
320     EXPECT_EQ(3, window_dim.padding_high());
321     EXPECT_EQ(1, window_dim.stride());
322     EXPECT_EQ(1, window_dim.base_dilation());
323   }
324 }
325 
326 // Convolve([abc], [x], base_dilation=2)
327 //   = Convolve([abc], Reverse([x]), base_dilation=2)
328 //   = BackwardInputConvolve([abc], [x], stride=2)
TEST_F(GpuConvRewriterTest, BackwardInputConvolve1x1Filter) {
  HloComputation::Builder builder(TestName());
  // NHWC dimension order.
  HloInstruction* output =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output"));
  // HWOI dimension order.
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel"));

  // Base dilation 2 on the column dimension marks this as the backward-input
  // pattern (equivalent to a forward conv with stride 2).
  Window window = default_conv_window_;
  window.mutable_dimensions(1)->set_base_dilation(2);

  const Shape conv_shape =
      ShapeInference::InferConvolveShape(
          output->shape(), kernel->shape(),
          /*feature_group_count=*/1,
          /*batch_group_count=*/1, window,
          tf_default_dnums_for_backward_input_,
          /*preferred_element_type=*/std::nullopt)
          .value();
  builder.AddInstruction(HloInstruction::CreateConvolve(
      conv_shape, /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, tf_default_dnums_for_backward_input_,
      DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());
  EXPECT_TRUE(RunPass(module.get()));
  EXPECT_THAT(entry_computation->root_instruction(),
              op::GetTupleElement(
                  op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
}
363 
364 // BackwardInputConvolve([abc], [x], stride=1) is equivalent to
365 // ForwardConvolve([abc], [x], stride=1). No need to fold it into backward input
366 // convolution.
TEST_F(GpuConvRewriterTest,
       BackwardInputConvolve1x1FilterEquivalentToForwardConvolve) {
  HloComputation::Builder builder(TestName());
  // NHWC dimension order.
  HloInstruction* output =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output"));
  // HWOI dimension order.
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel"));

  // With the default (stride-1, unpadded) window there is nothing backward
  // about this convolution, so the rewriter should emit a forward call.
  const Shape conv_shape =
      ShapeInference::InferConvolveShape(
          output->shape(), kernel->shape(), /*feature_group_count=*/1,
          /*batch_group_count=*/1, default_conv_window_,
          tf_default_dnums_for_backward_input_,
          /*preferred_element_type=*/std::nullopt)
          .value();
  builder.AddInstruction(HloInstruction::CreateConvolve(
      conv_shape, /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, default_conv_window_,
      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());
  EXPECT_TRUE(RunPass(module.get()));
  EXPECT_THAT(
      entry_computation->root_instruction(),
      op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
}
398 
399 // Extracted from Inception V3 training.
400 //
401 //                                  filter(HWIO)
402 //                                  3x3x192x320
403 //                                      |
404 //                                      v
405 //      gradients(NHWC)              reverse
406 //        20x4x4x320               3x3x192x320
407 //                    \            /
408 //                     \          /
409 //  conv (NHWC) with padding (low=2,high=3,interior=1)
410 //                     20x10x10x192
411 //
412 // Gradients are padded unevenly.
TEST_F(GpuConvRewriterTest, BackwardInputConvolveUnevenPaddingOnGradients) {
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* output =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "output"));
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), "kernel"));
  HloInstruction* reverse_kernel = builder.AddInstruction(
      HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));

  Window conv_window = default_conv_window_;
  for (int i = 0; i < 2; ++i) {
    conv_window.mutable_dimensions(i)->set_size(3);
    conv_window.mutable_dimensions(i)->set_padding_low(2);
    conv_window.mutable_dimensions(i)->set_padding_high(3);
    // Interior padding = 1.
    conv_window.mutable_dimensions(i)->set_base_dilation(2);
  }
  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
      /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
  // Verify the convolution's shape is consistent with ShapeInference.
  // (.value() instead of the deprecated ValueOrDie(), matching the rest of
  // this file.)
  CHECK(ShapeUtil::Compatible(
      conv->shape(), ShapeInference::InferConvolveShape(
                         output->shape(), reverse_kernel->shape(),
                         /*feature_group_count=*/1, /*batch_group_count=*/1,
                         conv_window, tf_default_dnums_for_backward_input_,
                         /*preferred_element_type=*/std::nullopt)
                         .value()));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());
  EXPECT_TRUE(RunPass(module.get()));
  ASSERT_THAT(entry_computation->root_instruction(),
              op::GetTupleElement(
                  op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
  const HloInstruction* custom_call =
      entry_computation->root_instruction()->operand(0);
  for (int i = 0; i < 2; ++i) {
    const WindowDimension& window_dim = custom_call->window().dimensions(i);
    // After the rewrite the window should have no padding and stride 2
    // (derived from the base dilation on the gradients).
    EXPECT_EQ(0, window_dim.padding_low());
    EXPECT_EQ(0, window_dim.padding_high());
    EXPECT_EQ(2, window_dim.stride());
    EXPECT_EQ(1, window_dim.base_dilation());
  }
}
462 
463 // Similar to BackwardInputConvolveUnevenPadding, but the low padding of the
464 // gradients exceeds kernel_size - 1. Therefore, this pattern cannot be fused.
TEST_F(GpuConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) {
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* output =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "output"));
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), "kernel"));
  HloInstruction* reverse_kernel = builder.AddInstruction(
      HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));

  // padding_low (3) exceeds kernel_size - 1 (2), so this cannot be fused
  // into a backward input convolution.
  Window conv_window = default_conv_window_;
  for (int i = 0; i < 2; ++i) {
    conv_window.mutable_dimensions(i)->set_size(3);
    conv_window.mutable_dimensions(i)->set_padding_low(3);
    conv_window.mutable_dimensions(i)->set_padding_high(2);
    conv_window.mutable_dimensions(i)->set_base_dilation(2);
  }
  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
      /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
  // Verify the convolution's shape is consistent with ShapeInference.
  // (.value() instead of the deprecated ValueOrDie(), matching the rest of
  // this file.)
  CHECK(ShapeUtil::Compatible(
      conv->shape(), ShapeInference::InferConvolveShape(
                         output->shape(), reverse_kernel->shape(),
                         /*feature_group_count=*/1, /*batch_group_count=*/1,
                         conv_window, tf_default_dnums_for_backward_input_,
                         /*preferred_element_type=*/std::nullopt)
                         .value()));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());
  EXPECT_TRUE(RunPass(module.get()));
  // The pattern falls back to a forward convolution custom call.
  EXPECT_THAT(
      entry_computation->root_instruction(),
      op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
}
504 
505 // Extracted from Resnet-50.
506 //
507 // For simplicity, we focus on the column dimension and ignore other dimensions.
508 // We use [?] to represent the shape instead of the content.
509 //
510 // Suppose operator FC does
511 //   [4] = conv([14], [3], stride=2, padding_high=1)  // Padding::kSame
512 //
513 // BC = BackwardInput(FC) does:
514 //   [14] = conv([7], reverse([3]),
515 //               padding_low=2, padding_high=1, base_dilation=2)
516 //
517 // We should fuse BC even though padding on activations is uneven, because
518 // GpuConvPaddingLegalization will canonicalize the fusion HLO.
TEST_F(GpuConvRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) {
  auto builder = HloComputation::Builder(TestName());
  // The gradients are in NCHW layout.
  HloInstruction* output =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {1, 1, 7, 1}), "output"));
  // The kernel is in HWIO layout.
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeShape(F32, {1, 3, 1, 1}), "kernel"));
  HloInstruction* reverse_kernel = builder.AddInstruction(
      HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));

  Window conv_window = default_conv_window_;
  WindowDimension* forward_conv_col_dim = conv_window.mutable_dimensions(1);
  forward_conv_col_dim->set_size(3);
  forward_conv_col_dim->set_padding_low(2);
  forward_conv_col_dim->set_padding_high(1);
  forward_conv_col_dim->set_base_dilation(2);
  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeUtil::MakeShape(F32, {1, 1, 14, 1}), output, reverse_kernel,
      /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
  // Verify the convolution's shape is consistent with ShapeInference.
  // (.value() instead of the deprecated ValueOrDie(), matching the rest of
  // this file.)
  CHECK(ShapeUtil::Compatible(
      conv->shape(), ShapeInference::InferConvolveShape(
                         output->shape(), reverse_kernel->shape(),
                         /*feature_group_count=*/1, /*batch_group_count=*/1,
                         conv_window, tf_default_dnums_for_backward_input_,
                         /*preferred_element_type=*/std::nullopt)
                         .value()));

  auto module = CreateNewVerifiedModule();
  const HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());
  EXPECT_TRUE(RunPass(module.get()));
  ASSERT_THAT(entry_computation->root_instruction(),
              op::GetTupleElement(
                  op::CustomCall(kCudnnConvBackwardInputCallTarget), 0));
  // The fused backward-input convolution keeps the uneven padding (0 low,
  // 1 high) on the column dimension for later canonicalization.
  const WindowDimension& backward_conv_col_dim =
      entry_computation->root_instruction()->operand(0)->window().dimensions(1);
  EXPECT_EQ(0, backward_conv_col_dim.padding_low());
  EXPECT_EQ(1, backward_conv_col_dim.padding_high());
}
563 
564 // For simplicity, we focus on the column dimension and ignore other dimensions.
565 // We use [?] to represent the shape instead of the content.
566 //
567 // Suppose operator FC does
568 //   [3] = conv([4], [2], padding_low=1, padding_high=-1)
569 //
570 // BC = BackwardInput(FC) does:
571 //   [4] = conv([3], reverse([2]), padding_high=2)
572 //
573 // We currently don't fuse BC because GpuConvPaddingLegalization
574 // doesn't support negative padding on the gradients of backward convolution
575 // (b/32744257).
TEST_F(GpuConvRewriterTest,
       BackwardInputConvolveNegativePaddingHighOnActivations) {
  auto builder = HloComputation::Builder(TestName());
  // The gradients are in NCHW layout.
  HloInstruction* output =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output"));
  // The kernel is in HWIO layout.
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeShape(F32, {1, 2, 1, 1}), "kernel"));
  HloInstruction* reverse_kernel = builder.AddInstruction(
      HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));

  Window conv_window = default_conv_window_;
  WindowDimension* forward_conv_col_dim = conv_window.mutable_dimensions(1);
  forward_conv_col_dim->set_size(2);
  forward_conv_col_dim->set_padding_high(2);
  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      ShapeUtil::MakeShape(F32, {1, 1, 4, 1}), output, reverse_kernel,
      /*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
  // Verify the convolution's shape is consistent with ShapeInference.
  // (.value() instead of the deprecated ValueOrDie(), matching the rest of
  // this file.)
  CHECK(ShapeUtil::Compatible(
      conv->shape(), ShapeInference::InferConvolveShape(
                         output->shape(), reverse_kernel->shape(),
                         /*feature_group_count=*/1, /*batch_group_count=*/1,
                         conv_window, tf_default_dnums_for_backward_input_,
                         /*preferred_element_type=*/std::nullopt)
                         .value()));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());
  EXPECT_TRUE(RunPass(module.get()));
  // Not fused as a backward input convolution; rewritten as forward instead.
  EXPECT_THAT(
      entry_computation->root_instruction(),
      op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0));
}
615 
616 // Check that we will materialize a reversed version of a constant in order to
617 // pattern-match a backwards input convolution.
TEST_F(GpuConvRewriterTest, BackwardInputConvolveConstantFilter) {
  // Build a literal 4x4x2x2 filter filled with 0, 1, 2, ... and splice its
  // textual form into the HLO below.
  Array4D<float> filter_data(4, 4, 2, 2);
  filter_data.FillIota(0);
  const std::string filter_str =
      LiteralUtil::CreateR4FromArray4D(filter_data).ToStringWithoutShape();

  const std::string hlo_text = absl::StrFormat(R"(
    HloModule test

    ENTRY entry_computation {
      param0 = f32[128,2,16,16]{3,2,1,0} parameter(0)
      constant = f32[4,4,2,2]{3,2,1,0} constant(%s)
      ROOT convolution = f32[128,2,32,32]{3,2,1,0} convolution(param0, constant),
          window={size=4x4 pad=2_2x2_2 lhs_dilate=2x2},
          dim_labels=bf01_01oi->bf01, feature_group_count=1
    })",
                                               filter_str);
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));

  EXPECT_TRUE(RunPass(m.get()));
  // The constant operand should have been reversed so the op matches the
  // backward-input custom call pattern.
  EXPECT_THAT(
      m->entry_computation()->root_instruction(),
      op::GetTupleElement(op::CustomCall(kCudnnConvBackwardInputCallTarget, _,
                                         op::Reverse(op::Constant())),
                          0));
}
644 
// A full-window fb01_io01 convolution should be matched as a cuDNN
// backward-filter custom call.
TEST_F(GpuConvRewriterTest, TestBackwardFilterPattern) {
  // The HLO text contains no format specifiers, so use the literal directly
  // instead of routing it through absl::StrFormat with no arguments.
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      input = f32[8,120,256,256] parameter(0)
      filter = f32[8,120,256,256] parameter(1)

      ROOT conv = f32[120,120,3,3] convolution(input, filter), window={size=256x256 pad=1_1x1_1}, dim_labels=fb01_io01->fb01
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  EXPECT_TRUE(RunPass(m.get()));
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              op::GetTupleElement(
                  op::CustomCall(kCudnnConvBackwardFilterCallTarget, _, _), 0));
}
662 
663 }  // anonymous namespace
664 }  // namespace gpu
665 }  // namespace xla
666