• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/cudnn_vectorize_convolutions.h"
17 
18 #include "tensorflow/compiler/xla/service/call_inliner.h"
19 #include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
20 #include "tensorflow/compiler/xla/service/hlo_parser.h"
21 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
22 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
23 #include "tensorflow/compiler/xla/status_macros.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25 #include "tensorflow/compiler/xla/util.h"
26 #include "tensorflow/core/platform/statusor.h"
27 
28 namespace xla {
29 namespace gpu {
30 namespace {
31 
32 namespace m = ::xla::match;
33 
34 class CudnnVectorizeConvolutionsTest : public HloTestBase {
35  protected:
36   // Runs this pass and some cleanup to make pattern-matching easier.
Run(std::pair<int,int> compute_capability,HloModule * module)37   StatusOr<bool> Run(std::pair<int, int> compute_capability,
38                      HloModule* module) {
39     CudnnVectorizeConvolutions pass(se::CudaComputeCapability{
40         compute_capability.first, compute_capability.second});
41     TF_ASSIGN_OR_RETURN(bool changed, RunHloPass(&pass, module));
42 
43     CallInliner inliner;
44     TF_RETURN_IF_ERROR(RunHloPass(&inliner, module).status());
45 
46     return changed;
47   }
48 };
49 
TEST_F(CudnnVectorizeConvolutionsTest, VectorizeTo4) {
  // s8 -> s8 conv in b01f layout with 40 input and 44 output features. Both
  // counts are multiples of 4 (but not of 32), so on compute capability 7.5
  // the pass is expected to emit an int8x4 conv and keep the backend config.
  constexpr char kHloText[] = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[10,20,30,40] parameter(0)
    filter = s8[2,2,40,44] parameter(1)
    ROOT result = (s8[10,20,30,44], u8[0]) custom-call(input, filter),
                  window={size=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward",
                  backend_config="{bar: 0}"
  })";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(bool modified, Run({7, 5}, hlo_module.get()));
  EXPECT_TRUE(modified);

  SCOPED_TRACE(hlo_module->ToString());
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();

  // Both parameters are reshaped so their (input) feature dim of 40 becomes
  // [10, 4]. The conv yields [.., 11, 4], which is reshaped for the first
  // element of the root tuple.
  const HloInstruction* conv_call = nullptr;
  ASSERT_THAT(
      root,
      GmockMatch(m::Tuple(
          m::Reshape(m::GetTupleElement(
                         m::CustomCall(&conv_call, kCudnnConvForwardCallTarget,
                                       m::Reshape(m::Parameter(0))
                                           .WithShape(S8, {10, 20, 30, 10, 4}),
                                       m::Reshape(m::Parameter(1))
                                           .WithShape(S8, {2, 2, 10, 4, 44})))
                         .WithShape(S8, {10, 20, 30, 11, 4})),
          m::Op())));

  EXPECT_EQ(conv_call->raw_backend_config_string(), "{bar: 0}");

  const ConvolutionDimensionNumbers& dim_nums =
      conv_call->convolution_dimension_numbers();
  ASSERT_EQ(dim_nums.input_spatial_dimensions().size(), 2);
  ASSERT_EQ(dim_nums.kernel_spatial_dimensions().size(), 2);
  ASSERT_EQ(dim_nums.output_spatial_dimensions().size(), 2);

  // Activations keep their b01f indices; the new vector dim is appended last.
  EXPECT_EQ(dim_nums.input_batch_dimension(), 0);
  EXPECT_EQ(dim_nums.input_spatial_dimensions()[0], 1);
  EXPECT_EQ(dim_nums.input_spatial_dimensions()[1], 2);
  EXPECT_EQ(dim_nums.input_feature_dimension(), 3);
  EXPECT_EQ(dim_nums.output_batch_dimension(), 0);
  EXPECT_EQ(dim_nums.output_spatial_dimensions()[0], 1);
  EXPECT_EQ(dim_nums.output_spatial_dimensions()[1], 2);
  EXPECT_EQ(dim_nums.output_feature_dimension(), 3);

  // In the kernel the vector dim lands right after the input-feature dim, so
  // only the output-feature index moves (3 -> 4).
  EXPECT_EQ(dim_nums.kernel_spatial_dimensions()[0], 0);
  EXPECT_EQ(dim_nums.kernel_spatial_dimensions()[1], 1);
  EXPECT_EQ(dim_nums.kernel_input_feature_dimension(), 2);
  EXPECT_EQ(dim_nums.kernel_output_feature_dimension(), 4);
}
105 
TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo4UnsupportedFilterType) {
  // The filter here is f32 rather than s8. The vectorize pass is expected to
  // consult CudnnSupportsOptimizedIntegerConvolution(), which should reject
  // this conv, so the pass must report that nothing changed.
  constexpr char kHloText[] = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[10,20,30,40] parameter(0)
    filter = f32[2,2,40,44] parameter(1)
    ROOT result = (s8[10,20,30,44], u8[0]) custom-call(input, filter),
                  window={size=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward",
                  backend_config="{bar: 0}"
  })";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(bool modified, Run({7, 5}, hlo_module.get()));
  EXPECT_FALSE(modified);
}
125 
TEST_F(CudnnVectorizeConvolutionsTest, VectorizeTo4NCHW) {
  // Same idea as VectorizeTo4 but in bf01 / io01 layout: the vector dim is
  // inserted right after each feature dim, shifting the spatial dims by one.
  constexpr char kHloText[] = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[10,48,20,30] parameter(0)
    filter = s8[48,44,2,2] parameter(1)
    ROOT result = (s8[10,44,20,30], u8[0]) custom-call(input, filter),
                  window={size=2x2}, dim_labels=bf01_io01->bf01,
                  custom_call_target="__cudnn$convForward"
  })";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(bool modified, Run({7, 5}, hlo_module.get()));
  EXPECT_TRUE(modified);

  SCOPED_TRACE(hlo_module->ToString());
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();

  const HloInstruction* conv_call = nullptr;
  ASSERT_THAT(
      root,
      GmockMatch(m::Tuple(
          m::Reshape(m::GetTupleElement(
                         m::CustomCall(&conv_call, kCudnnConvForwardCallTarget,
                                       m::Reshape(m::Parameter(0))
                                           .WithShape(S8, {10, 12, 4, 20, 30}),
                                       m::Reshape(m::Parameter(1))
                                           .WithShape(S8, {12, 4, 44, 2, 2})))
                         .WithShape(S8, {10, 11, 4, 20, 30})),
          m::Op())));

  const ConvolutionDimensionNumbers& dim_nums =
      conv_call->convolution_dimension_numbers();
  ASSERT_EQ(dim_nums.input_spatial_dimensions().size(), 2);
  ASSERT_EQ(dim_nums.kernel_spatial_dimensions().size(), 2);
  ASSERT_EQ(dim_nums.output_spatial_dimensions().size(), 2);

  // Activations: [batch, feature, vect, 0, 1].
  EXPECT_EQ(dim_nums.input_batch_dimension(), 0);
  EXPECT_EQ(dim_nums.input_feature_dimension(), 1);
  EXPECT_EQ(dim_nums.input_spatial_dimensions()[0], 3);
  EXPECT_EQ(dim_nums.input_spatial_dimensions()[1], 4);
  EXPECT_EQ(dim_nums.output_batch_dimension(), 0);
  EXPECT_EQ(dim_nums.output_feature_dimension(), 1);
  EXPECT_EQ(dim_nums.output_spatial_dimensions()[0], 3);
  EXPECT_EQ(dim_nums.output_spatial_dimensions()[1], 4);

  // Kernel: [in-feature, vect, out-feature, 0, 1].
  EXPECT_EQ(dim_nums.kernel_input_feature_dimension(), 0);
  EXPECT_EQ(dim_nums.kernel_output_feature_dimension(), 2);
  EXPECT_EQ(dim_nums.kernel_spatial_dimensions()[0], 3);
  EXPECT_EQ(dim_nums.kernel_spatial_dimensions()[1], 4);
}
178 
TEST_F(CudnnVectorizeConvolutionsTest, IncrementAllDnums) {
  // Feature dims come first in every operand (fb01 / i01o), so inserting the
  // vector dim after them must bump every other dimension index by one.
  //
  // NOTE(review): with i01o the filter's output-feature extent is 3 while the
  // result has 16 features, and its spatial extents (16, 3) do not match the
  // 2x2 window. The conv appears to be only a vehicle for checking dnums
  // bookkeeping; confirm this is intentional.
  constexpr char kHloText[] = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[16,16,16,16] parameter(0)
    filter = s8[16,16,3,3] parameter(1)
    ROOT result = (s8[16,16,16,16], u8[0]) custom-call(input, filter),
                  window={size=2x2}, dim_labels=fb01_i01o->fb01,
                  custom_call_target="__cudnn$convForward"
  })";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(bool modified, Run({7, 5}, hlo_module.get()));
  EXPECT_TRUE(modified);

  SCOPED_TRACE(hlo_module->ToString());
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();

  const HloInstruction* conv_call = nullptr;
  ASSERT_THAT(
      root,
      GmockMatch(m::Tuple(
          m::Reshape(m::GetTupleElement(
                         m::CustomCall(&conv_call, kCudnnConvForwardCallTarget,
                                       m::Reshape(m::Parameter(0))
                                           .WithShape(S8, {4, 4, 16, 16, 16}),
                                       m::Reshape(m::Parameter(1))
                                           .WithShape(S8, {4, 4, 16, 3, 3})))
                         .WithShape(S8, {4, 4, 16, 16, 16})),
          m::Op())));

  const ConvolutionDimensionNumbers& dim_nums =
      conv_call->convolution_dimension_numbers();
  ASSERT_EQ(dim_nums.input_spatial_dimensions().size(), 2);
  ASSERT_EQ(dim_nums.kernel_spatial_dimensions().size(), 2);
  ASSERT_EQ(dim_nums.output_spatial_dimensions().size(), 2);

  // Activations: [feature, vect, batch, 0, 1].
  EXPECT_EQ(dim_nums.input_feature_dimension(), 0);
  EXPECT_EQ(dim_nums.input_batch_dimension(), 2);
  EXPECT_EQ(dim_nums.input_spatial_dimensions()[0], 3);
  EXPECT_EQ(dim_nums.input_spatial_dimensions()[1], 4);
  EXPECT_EQ(dim_nums.output_feature_dimension(), 0);
  EXPECT_EQ(dim_nums.output_batch_dimension(), 2);
  EXPECT_EQ(dim_nums.output_spatial_dimensions()[0], 3);
  EXPECT_EQ(dim_nums.output_spatial_dimensions()[1], 4);

  // Kernel: [in-feature, vect, 0, 1, out-feature].
  EXPECT_EQ(dim_nums.kernel_input_feature_dimension(), 0);
  EXPECT_EQ(dim_nums.kernel_spatial_dimensions()[0], 2);
  EXPECT_EQ(dim_nums.kernel_spatial_dimensions()[1], 3);
  EXPECT_EQ(dim_nums.kernel_output_feature_dimension(), 4);
}
231 
TEST_F(CudnnVectorizeConvolutionsTest, FilterDnums) {
  // Activations are bf01 while the filter is 01io. This checks that the kernel
  // dnums are updated independently of the activation dnums.
  constexpr char kHloText[] = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[1,20,9,9] parameter(0)
    filter = s8[3,3,20,32] parameter(1)
    ROOT result = (s8[1,32,9,9], u8[0]) custom-call(s8[1,20,9,9] input, s8[3,3,20,32] filter),
                  window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01,
                  custom_call_target="__cudnn$convForward"
  })";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(bool modified, Run({7, 5}, hlo_module.get()));
  EXPECT_TRUE(modified);

  SCOPED_TRACE(hlo_module->ToString());
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();

  const HloInstruction* conv_call = nullptr;
  ASSERT_THAT(
      root,
      GmockMatch(m::Tuple(
          m::Reshape(m::GetTupleElement(
                         m::CustomCall(&conv_call, kCudnnConvForwardCallTarget,
                                       m::Reshape(m::Parameter(0))
                                           .WithShape(S8, {1, 5, 4, 9, 9}),
                                       m::Reshape(m::Parameter(1))
                                           .WithShape(S8, {3, 3, 5, 4, 32})))
                         .WithShape(S8, {1, 8, 4, 9, 9})),
          m::Op())));

  const ConvolutionDimensionNumbers& dim_nums =
      conv_call->convolution_dimension_numbers();
  ASSERT_EQ(dim_nums.input_spatial_dimensions().size(), 2);
  ASSERT_EQ(dim_nums.kernel_spatial_dimensions().size(), 2);
  ASSERT_EQ(dim_nums.output_spatial_dimensions().size(), 2);

  // Kernel: [0, 1, in-feature, vect, out-feature].
  EXPECT_EQ(dim_nums.kernel_spatial_dimensions()[0], 0);
  EXPECT_EQ(dim_nums.kernel_spatial_dimensions()[1], 1);
  EXPECT_EQ(dim_nums.kernel_input_feature_dimension(), 2);
  EXPECT_EQ(dim_nums.kernel_output_feature_dimension(), 4);

  // Activations: [batch, feature, vect, 0, 1].
  EXPECT_EQ(dim_nums.input_batch_dimension(), 0);
  EXPECT_EQ(dim_nums.input_feature_dimension(), 1);
  EXPECT_EQ(dim_nums.input_spatial_dimensions()[0], 3);
  EXPECT_EQ(dim_nums.input_spatial_dimensions()[1], 4);
  EXPECT_EQ(dim_nums.output_batch_dimension(), 0);
  EXPECT_EQ(dim_nums.output_feature_dimension(), 1);
  EXPECT_EQ(dim_nums.output_spatial_dimensions()[0], 3);
  EXPECT_EQ(dim_nums.output_spatial_dimensions()[1], 4);
}
284 
TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo4) {
  // 41 input features is not a multiple of 4 (nor of 32), so the pass is
  // expected to leave this conv alone and report no change.
  constexpr char kHloText[] = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[10,20,30,41] parameter(0)
    filter = s8[2,2,41,44] parameter(1)
    ROOT result = (s8[10,20,30,44], u8[0]) custom-call(input, filter),
                  window={size=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })";
  // Surface a parse/verify error as a test failure rather than a crash.
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  // Run() constructs the pass itself; no separate pass instance is needed.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));

  SCOPED_TRACE(module->ToString());
  EXPECT_FALSE(changed);
}
303 
304 // Don't vectorize int8_t -> int32_t into int8x4 or int8x32; this is not
305 // supported in cudnn.
TEST_F(CudnnVectorizeConvolutionsTest,NoVectorizeTo4IfOutputIsS32)306 TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo4IfOutputIsS32) {
307   auto module = ParseAndReturnVerifiedModule(R"(
308   HloModule TestModule
309 
310   ENTRY TestComputation {
311     input = s8[10,20,30,41] parameter(0)
312     filter = s8[2,2,41,44] parameter(1)
313     ROOT result = (s32[10,20,30,44], u8[0]) custom-call(input, filter),
314                   window={size=2x2}, dim_labels=b01f_01io->b01f,
315                   custom_call_target="__cudnn$convForward"
316   })")
317                     .ValueOrDie();
318   TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
319   SCOPED_TRACE(module->ToString());
320   EXPECT_FALSE(changed);
321 }
322 
323 // Don't vectorize int8_t -> float into int8x4 or int8x32.  Vectorizing to
324 // int8x4 *is* allowed by cudnn, but we don't do it at the moment.
TEST_F(CudnnVectorizeConvolutionsTest,NoVectorizeTo4IfOutputIsF32)325 TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo4IfOutputIsF32) {
326   auto module = ParseAndReturnVerifiedModule(R"(
327   HloModule TestModule
328 
329   ENTRY TestComputation {
330     input = s8[10,20,30,41] parameter(0)
331     filter = s8[2,2,41,44] parameter(1)
332     ROOT result = (f32[10,20,30,44], u8[0]) custom-call(input, filter),
333                   window={size=2x2}, dim_labels=b01f_01io->b01f,
334                   custom_call_target="__cudnn$convForward"
335   })")
336                     .ValueOrDie();
337   TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
338   SCOPED_TRACE(module->ToString());
339   EXPECT_FALSE(changed);
340 }
341 
TEST_F(CudnnVectorizeConvolutionsTest, VectorizeTo32) {
  // 64 input / 128 output features are multiples of 32. On compute capability
  // 7.5 the pass is expected to choose int8x32 rather than int8x4.
  constexpr char kHloText[] = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[10,20,30,64] parameter(0)
    filter = s8[2,2,64,128] parameter(1)
    ROOT result = (s8[10,20,30,128], u8[0]) custom-call(input, filter),
                  window={size=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(bool modified, Run({7, 5}, hlo_module.get()));
  EXPECT_TRUE(modified);

  SCOPED_TRACE(hlo_module->ToString());
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();

  // Features split as 64 -> [2, 32] on the operands and 128 -> [4, 32] on the
  // conv result.
  const HloInstruction* conv_call = nullptr;
  ASSERT_THAT(
      root,
      GmockMatch(m::Tuple(
          m::Reshape(m::GetTupleElement(
                         m::CustomCall(&conv_call, kCudnnConvForwardCallTarget,
                                       m::Reshape(m::Parameter(0))
                                           .WithShape(S8, {10, 20, 30, 2, 32}),
                                       m::Reshape(m::Parameter(1))
                                           .WithShape(S8, {2, 2, 2, 32, 128})))
                         .WithShape(S8, {10, 20, 30, 4, 32})),
          m::Op())));
}
373 
TEST_F(CudnnVectorizeConvolutionsTest, BiasAndSideInput) {
  // Fused conv with bias and side input. The bias holds one value per output
  // feature, and the side input is added to the conv result, so it has the
  // result's shape. Both are therefore sized with the 128 output features,
  // not the 64 input features.
  //
  // Expected rewrite: the bias passes through untouched, while input, filter
  // and side input each have their feature dim split by 32.
  constexpr char kHloText[] = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[10,20,30,64] parameter(0)
    filter = s8[2,2,64,128] parameter(1)
    bias = f32[128] parameter(2)
    side_input = s8[10,20,30,128] parameter(3)

    ROOT result = (s8[10,20,30,128], u8[0]) custom-call(input, filter, bias, side_input),
                  window={size=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
  EXPECT_TRUE(changed);

  SCOPED_TRACE(module->ToString());
  auto* root = module->entry_computation()->root_instruction();

  const HloInstruction* conv = nullptr;
  ASSERT_THAT(
      root,
      GmockMatch(m::Tuple(
          m::Reshape(m::GetTupleElement(
                         m::CustomCall(&conv, kCudnnConvForwardCallTarget,
                                       m::Reshape(m::Parameter(0))
                                           .WithShape(S8, {10, 20, 30, 2, 32}),
                                       m::Reshape(m::Parameter(1))
                                           .WithShape(S8, {2, 2, 2, 32, 128}),
                                       m::Parameter(2),
                                       // Side input follows the output layout:
                                       // 128 features -> [4, 32].
                                       m::Reshape(m::Parameter(3))
                                           .WithShape(S8, {10, 20, 30, 4, 32})))
                         .WithShape(S8, {10, 20, 30, 4, 32})),
          m::Op())));
}
411 
TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo32) {
  // Same conv as VectorizeTo32, but targeting compute capability 7.0. The
  // pass is expected to fall back to int8x4 instead of int8x32.
  constexpr char kHloText[] = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[10,20,30,64] parameter(0)
    filter = s8[2,2,64,128] parameter(1)
    ROOT result = (s8[10,20,30,128], u8[0]) custom-call(input, filter),
                  window={size=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(bool modified, Run({7, 0}, hlo_module.get()));
  EXPECT_TRUE(modified);

  SCOPED_TRACE(hlo_module->ToString());
  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();

  // Features split by 4: 64 -> [16, 4] on the operands, 128 -> [32, 4] out.
  const HloInstruction* conv_call = nullptr;
  ASSERT_THAT(
      root,
      GmockMatch(m::Tuple(
          m::Reshape(m::GetTupleElement(
                         m::CustomCall(&conv_call, kCudnnConvForwardCallTarget,
                                       m::Reshape(m::Parameter(0))
                                           .WithShape(S8, {10, 20, 30, 16, 4}),
                                       m::Reshape(m::Parameter(1))
                                           .WithShape(S8, {2, 2, 16, 4, 128})))
                         .WithShape(S8, {10, 20, 30, 32, 4})),
          m::Op())));
}
443 
TEST_F(CudnnVectorizeConvolutionsTest,Vectorize4To32)444 TEST_F(CudnnVectorizeConvolutionsTest, Vectorize4To32) {
445   auto module = ParseAndReturnVerifiedModule(R"(
446   HloModule TestModule
447 
448   ENTRY TestComputation {
449     input = s8[10,20,30,16,4] parameter(0)
450     filter = s8[3,5,16,192,4] parameter(1)
451     bias = f32[10] parameter(2)
452     side_input = s8[10,20,30,16,4] parameter(3)
453     ROOT result = (s8[10,20,30,48,4], u8[0]) custom-call(input, filter, bias, side_input),
454                   window={size=3x5}, dim_labels=b01f_01io->b01f,
455                   custom_call_target="__cudnn$convForward",
456                   backend_config="{foo: 42}"
457   })")
458                     .ValueOrDie();
459   TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
460   EXPECT_TRUE(changed);
461 
462   SCOPED_TRACE(module->ToString());
463   auto* root = module->entry_computation()->root_instruction();
464 
465   const HloInstruction* conv = nullptr;
466   auto conv_pat =
467       m::GetTupleElement(
468           m::CustomCall(
469               &conv, kCudnnConvForwardCallTarget,
470               m::Reshape(m::Transpose(m::Reshape(m::Parameter(0))
471                                           .WithShape(S8, {10, 20, 30, 2, 8, 4}))
472                              .WithShape(S8, {10, 20, 30, 2, 8, 4}))
473                   .WithShape(S8, {10, 20, 30, 2, 32}),
474               m::Reshape(m::Transpose(m::Reshape(m::Parameter(1))
475                                           .WithShape(S8, {3, 5, 2, 8, 192, 4}))
476                              .WithShape(S8, {3, 5, 2, 192, 8, 4}))
477                   .WithShape(S8, {3, 5, 2, 192, 32}),
478               m::Parameter(2),
479               m::Reshape(m::Transpose(m::Reshape(m::Parameter(3))
480                                           .WithShape(S8, {10, 20, 30, 2, 8, 4}))
481                              .WithShape(S8, {10, 20, 30, 2, 8, 4}))
482                   .WithShape(S8, {10, 20, 30, 2, 32})))
483           .WithShape(S8, {10, 20, 30, 6, 32});
484   ASSERT_THAT(root, GmockMatch(m::Tuple(
485                         m::Reshape(m::Transpose(m::Reshape(conv_pat).WithShape(
486                                                     S8, {10, 20, 30, 6, 8, 4}))
487                                        .WithShape(S8, {10, 20, 30, 6, 8, 4}))
488                             .WithShape(S8, {10, 20, 30, 48, 4}),
489                         m::Op())));
490 
491   EXPECT_EQ(conv->raw_backend_config_string(), "{foo: 42}");
492 
493   const ConvolutionDimensionNumbers& dnums =
494       conv->convolution_dimension_numbers();
495   ASSERT_EQ(dnums.input_spatial_dimensions().size(), 2);
496   ASSERT_EQ(dnums.kernel_spatial_dimensions().size(), 2);
497   ASSERT_EQ(dnums.output_spatial_dimensions().size(), 2);
498 
499   EXPECT_EQ(dnums.input_batch_dimension(), 0);
500   EXPECT_EQ(dnums.input_spatial_dimensions()[0], 1);
501   EXPECT_EQ(dnums.input_spatial_dimensions()[1], 2);
502   EXPECT_EQ(dnums.input_feature_dimension(), 3);
503 
504   EXPECT_EQ(dnums.kernel_spatial_dimensions()[0], 0);
505   EXPECT_EQ(dnums.kernel_spatial_dimensions()[1], 1);
506   EXPECT_EQ(dnums.kernel_input_feature_dimension(), 2);
507   EXPECT_EQ(dnums.kernel_output_feature_dimension(), 3);
508 
509   EXPECT_EQ(dnums.output_batch_dimension(), 0);
510   EXPECT_EQ(dnums.output_spatial_dimensions()[0], 1);
511   EXPECT_EQ(dnums.output_spatial_dimensions()[1], 2);
512   EXPECT_EQ(dnums.output_feature_dimension(), 3);
513 }
514 
TEST_F(CudnnVectorizeConvolutionsTest,Vectorize4To32NCHW)515 TEST_F(CudnnVectorizeConvolutionsTest, Vectorize4To32NCHW) {
516   auto module = ParseAndReturnVerifiedModule(R"(
517   HloModule TestModule
518 
519   ENTRY TestComputation {
520     input = s8[10,16,20,30,4] parameter(0)
521     filter = s8[16,128,2,2,4] parameter(1)
522     bias = f32[10] parameter(2)
523     side_input = s8[10,16,20,30,4] parameter(3)
524     ROOT result = (s8[10,32,20,30,4], u8[0]) custom-call(input, filter, bias, side_input),
525                   window={size=2x2}, dim_labels=bf01_io01->bf01,
526                   custom_call_target="__cudnn$convForward"
527   })")
528                     .ValueOrDie();
529   TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
530   EXPECT_TRUE(changed);
531 
532   SCOPED_TRACE(module->ToString());
533   auto* root = module->entry_computation()->root_instruction();
534 
535   const HloInstruction* conv = nullptr;
536   auto conv_pat =
537       m::GetTupleElement(
538           m::CustomCall(
539               &conv, kCudnnConvForwardCallTarget,
540               m::Reshape(m::Transpose(m::Reshape(m::Parameter(0))
541                                           .WithShape(S8, {10, 2, 8, 20, 30, 4}))
542                              .WithShape(S8, {10, 2, 20, 30, 8, 4}))
543                   .WithShape(S8, {10, 2, 20, 30, 32}),
544               m::Reshape(m::Transpose(m::Reshape(m::Parameter(1))
545                                           .WithShape(S8, {2, 8, 128, 2, 2, 4}))
546                              .WithShape(S8, {2, 128, 2, 2, 8, 4}))
547                   .WithShape(S8, {2, 128, 2, 2, 32}),
548               m::Parameter(2),
549               m::Reshape(m::Transpose(m::Reshape(m::Parameter(3))
550                                           .WithShape(S8, {10, 2, 8, 20, 30, 4}))
551                              .WithShape(S8, {10, 2, 20, 30, 8, 4}))
552                   .WithShape(S8, {10, 2, 20, 30, 32})))
553           .WithShape(S8, {10, 4, 20, 30, 32});
554   ASSERT_THAT(root, GmockMatch(m::Tuple(
555                         m::Reshape(m::Transpose(m::Reshape(conv_pat).WithShape(
556                                                     S8, {10, 4, 20, 30, 8, 4}))
557                                        .WithShape(S8, {10, 4, 8, 20, 30, 4}))
558                             .WithShape(S8, {10, 32, 20, 30, 4}),
559                         m::Op())));
560 
561   const ConvolutionDimensionNumbers& dnums =
562       conv->convolution_dimension_numbers();
563   ASSERT_EQ(dnums.input_spatial_dimensions().size(), 2);
564   ASSERT_EQ(dnums.kernel_spatial_dimensions().size(), 2);
565   ASSERT_EQ(dnums.output_spatial_dimensions().size(), 2);
566 
567   EXPECT_EQ(dnums.input_batch_dimension(), 0);
568   EXPECT_EQ(dnums.input_feature_dimension(), 1);
569   EXPECT_EQ(dnums.input_spatial_dimensions()[0], 2);
570   EXPECT_EQ(dnums.input_spatial_dimensions()[1], 3);
571 
572   EXPECT_EQ(dnums.kernel_input_feature_dimension(), 0);
573   EXPECT_EQ(dnums.kernel_output_feature_dimension(), 1);
574   EXPECT_EQ(dnums.kernel_spatial_dimensions()[0], 2);
575   EXPECT_EQ(dnums.kernel_spatial_dimensions()[1], 3);
576 
577   EXPECT_EQ(dnums.output_batch_dimension(), 0);
578   EXPECT_EQ(dnums.output_feature_dimension(), 1);
579   EXPECT_EQ(dnums.output_spatial_dimensions()[0], 2);
580   EXPECT_EQ(dnums.output_spatial_dimensions()[1], 3);
581 }
582 
TEST_F(CudnnVectorizeConvolutionsTest,Vectorize4To32VectorDimFirst)583 TEST_F(CudnnVectorizeConvolutionsTest, Vectorize4To32VectorDimFirst) {
584   auto module = ParseAndReturnVerifiedModule(R"(
585   HloModule TestModule
586 
587   ENTRY TestComputation {
588     input = s8[4,10,20,30,16] parameter(0)
589     filter = s8[4,3,5,16,192] parameter(1)
590     bias = f32[10] parameter(2)
591     side_input = s8[4,10,20,30,16] parameter(3)
592     ROOT result = (s8[4,10,20,30,48], u8[0]) custom-call(input, filter, bias, side_input),
593                   window={size=3x5}, dim_labels=?b01f_?01io->?b01f,
594                   custom_call_target="__cudnn$convForward"
595   })")
596                     .ValueOrDie();
597   TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
598   EXPECT_TRUE(changed);
599 
600   SCOPED_TRACE(module->ToString());
601   auto* root = module->entry_computation()->root_instruction();
602 
603   const HloInstruction* conv = nullptr;
604   auto conv_pat =
605       m::GetTupleElement(
606           m::CustomCall(
607               &conv, kCudnnConvForwardCallTarget,
608               m::Reshape(m::Transpose(m::Reshape(m::Parameter(0))
609                                           .WithShape(S8, {4, 10, 20, 30, 2, 8}))
610                              .WithShape(S8, {8, 4, 10, 20, 30, 2}))
611                   .WithShape(S8, {32, 10, 20, 30, 2}),
612               m::Reshape(m::Transpose(m::Reshape(m::Parameter(1))
613                                           .WithShape(S8, {4, 3, 5, 2, 8, 192}))
614                              .WithShape(S8, {8, 4, 3, 5, 2, 192}))
615                   .WithShape(S8, {32, 3, 5, 2, 192}),
616               m::Parameter(2),
617               m::Reshape(m::Transpose(m::Reshape(m::Parameter(3))
618                                           .WithShape(S8, {4, 10, 20, 30, 2, 8}))
619                              .WithShape(S8, {8, 4, 10, 20, 30, 2}))
620                   .WithShape(S8, {32, 10, 20, 30, 2})))
621           .WithShape(S8, {32, 10, 20, 30, 6});
622   ASSERT_THAT(root, GmockMatch(m::Tuple(
623                         m::Reshape(m::Transpose(m::Reshape(conv_pat).WithShape(
624                                                     S8, {8, 4, 10, 20, 30, 6}))
625                                        .WithShape(S8, {4, 10, 20, 30, 6, 8}))
626                             .WithShape(S8, {4, 10, 20, 30, 48}),
627                         m::Op())));
628 
629   const ConvolutionDimensionNumbers& dnums =
630       conv->convolution_dimension_numbers();
631   ASSERT_EQ(dnums.input_spatial_dimensions().size(), 2);
632   ASSERT_EQ(dnums.kernel_spatial_dimensions().size(), 2);
633   ASSERT_EQ(dnums.output_spatial_dimensions().size(), 2);
634 
635   EXPECT_EQ(dnums.input_batch_dimension(), 1);
636   EXPECT_EQ(dnums.input_spatial_dimensions()[0], 2);
637   EXPECT_EQ(dnums.input_spatial_dimensions()[1], 3);
638   EXPECT_EQ(dnums.input_feature_dimension(), 4);
639 
640   EXPECT_EQ(dnums.kernel_spatial_dimensions()[0], 1);
641   EXPECT_EQ(dnums.kernel_spatial_dimensions()[1], 2);
642   EXPECT_EQ(dnums.kernel_input_feature_dimension(), 3);
643   EXPECT_EQ(dnums.kernel_output_feature_dimension(), 4);
644 
645   EXPECT_EQ(dnums.output_batch_dimension(), 1);
646   EXPECT_EQ(dnums.output_spatial_dimensions()[0], 2);
647   EXPECT_EQ(dnums.output_spatial_dimensions()[1], 3);
648   EXPECT_EQ(dnums.output_feature_dimension(), 4);
649 }
650 
TEST_F(CudnnVectorizeConvolutionsTest, NoVectorize4To32) {
  // The conv is already int8x4 (trailing vector dim of 4). Targeting compute
  // capability 7.0, the pass is expected not to revectorize it to int8x32,
  // so it must report no change.
  constexpr char kHloText[] = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[10,20,30,16,4] parameter(0)
    filter = s8[2,2,16,128,4] parameter(1)
    bias = f32[10] parameter(2)
    side_input = s8[10,20,30,16,4] parameter(3)
    ROOT result = (s8[10,20,30,32,4], u8[0]) custom-call(input, filter, bias, side_input),
                  window={size=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(bool modified, Run({7, 0}, hlo_module.get()));
  EXPECT_FALSE(modified);
}
668 
669 }  // namespace
670 }  // namespace gpu
671 }  // namespace xla
672