1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/cudnn_vectorize_convolutions.h"
17
18 #include "tensorflow/compiler/xla/service/call_inliner.h"
19 #include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
20 #include "tensorflow/compiler/xla/service/hlo_parser.h"
21 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
22 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
23 #include "tensorflow/compiler/xla/status_macros.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25 #include "tensorflow/compiler/xla/util.h"
26 #include "tensorflow/core/platform/statusor.h"
27
28 namespace xla {
29 namespace gpu {
30 namespace {
31
32 namespace m = ::xla::match;
33
34 class CudnnVectorizeConvolutionsTest : public HloTestBase {
35 protected:
36 // Runs this pass and some cleanup to make pattern-matching easier.
Run(std::pair<int,int> compute_capability,HloModule * module)37 StatusOr<bool> Run(std::pair<int, int> compute_capability,
38 HloModule* module) {
39 CudnnVectorizeConvolutions pass(se::CudaComputeCapability{
40 compute_capability.first, compute_capability.second});
41 TF_ASSIGN_OR_RETURN(bool changed, RunHloPass(&pass, module));
42
43 CallInliner inliner;
44 TF_RETURN_IF_ERROR(RunHloPass(&inliner, module).status());
45
46 return changed;
47 }
48 };
49
TEST_F(CudnnVectorizeConvolutionsTest, VectorizeTo4) {
  // b01f s8 convolution with 40 input and 44 output features; the
  // expectations below check that both are split by a vector size of 4.
  const char* const kHloText = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[10,20,30,40] parameter(0)
    filter = s8[2,2,40,44] parameter(1)
    ROOT result = (s8[10,20,30,44], u8[0]) custom-call(input, filter),
                  window={size=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward",
                  backend_config="{bar: 0}"
  })";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(bool rewritten, Run({7, 5}, hlo_module.get()));
  EXPECT_TRUE(rewritten);

  // Print the rewritten module alongside any failure below.
  SCOPED_TRACE(hlo_module->ToString());

  // The root must be a tuple whose first element is a reshape of the
  // vectorized convolution's output; `vectorized_conv` captures the
  // custom-call itself for the checks that follow.
  const HloInstruction* vectorized_conv = nullptr;
  ASSERT_THAT(
      hlo_module->entry_computation()->root_instruction(),
      GmockMatch(m::Tuple(
          m::Reshape(m::GetTupleElement(
                         m::CustomCall(&vectorized_conv,
                                       kCudnnConvForwardCallTarget,
                                       m::Reshape(m::Parameter(0))
                                           .WithShape(S8, {10, 20, 30, 10, 4}),
                                       m::Reshape(m::Parameter(1))
                                           .WithShape(S8, {2, 2, 10, 4, 44})))
                         .WithShape(S8, {10, 20, 30, 11, 4})),
          m::Op())));

  // The backend config string must be carried over to the new convolution.
  EXPECT_EQ(vectorized_conv->raw_backend_config_string(), "{bar: 0}");

  const ConvolutionDimensionNumbers& dims =
      vectorized_conv->convolution_dimension_numbers();
  // Size checks are fatal because the spatial dimensions are indexed below.
  ASSERT_EQ(dims.input_spatial_dimensions().size(), 2);
  ASSERT_EQ(dims.kernel_spatial_dimensions().size(), 2);
  ASSERT_EQ(dims.output_spatial_dimensions().size(), 2);

  // Input: b01f.
  EXPECT_EQ(dims.input_batch_dimension(), 0);
  EXPECT_EQ(dims.input_spatial_dimensions()[0], 1);
  EXPECT_EQ(dims.input_spatial_dimensions()[1], 2);
  EXPECT_EQ(dims.input_feature_dimension(), 3);

  // Kernel: the output-feature dimension moves from 3 to 4 because the new
  // vector dimension sits right after the input-feature dimension.
  EXPECT_EQ(dims.kernel_spatial_dimensions()[0], 0);
  EXPECT_EQ(dims.kernel_spatial_dimensions()[1], 1);
  EXPECT_EQ(dims.kernel_input_feature_dimension(), 2);
  EXPECT_EQ(dims.kernel_output_feature_dimension(), 4);

  // Output: b01f.
  EXPECT_EQ(dims.output_batch_dimension(), 0);
  EXPECT_EQ(dims.output_spatial_dimensions()[0], 1);
  EXPECT_EQ(dims.output_spatial_dimensions()[1], 2);
  EXPECT_EQ(dims.output_feature_dimension(), 3);
}
105
TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo4UnsupportedFilterType) {
  // The filter is f32 while the input is s8. The vectorize pass is expected
  // to consult CudnnSupportsOptimizedIntegerConvolution(), which should
  // reject this convolution because of the filter type, leaving the module
  // unchanged.
  const char* const kHloText = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[10,20,30,40] parameter(0)
    filter = f32[2,2,40,44] parameter(1)
    ROOT result = (s8[10,20,30,44], u8[0]) custom-call(input, filter),
                  window={size=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward",
                  backend_config="{bar: 0}"
  })";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(bool rewritten, Run({7, 5}, hlo_module.get()));
  EXPECT_FALSE(rewritten);
}
125
TEST_F(CudnnVectorizeConvolutionsTest, VectorizeTo4NCHW) {
  // bf01 layout: the feature dimension is dimension 1 of the input/output
  // and dimension 0 (input features) of the filter.
  const char* const kHloText = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[10,48,20,30] parameter(0)
    filter = s8[48,44,2,2] parameter(1)
    ROOT result = (s8[10,44,20,30], u8[0]) custom-call(input, filter),
                  window={size=2x2}, dim_labels=bf01_io01->bf01,
                  custom_call_target="__cudnn$convForward"
  })";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(bool rewritten, Run({7, 5}, hlo_module.get()));
  EXPECT_TRUE(rewritten);

  // Print the rewritten module alongside any failure below.
  SCOPED_TRACE(hlo_module->ToString());

  // The vector dimension of size 4 is expected right after each split
  // feature dimension.
  const HloInstruction* vectorized_conv = nullptr;
  ASSERT_THAT(
      hlo_module->entry_computation()->root_instruction(),
      GmockMatch(m::Tuple(
          m::Reshape(m::GetTupleElement(
                         m::CustomCall(&vectorized_conv,
                                       kCudnnConvForwardCallTarget,
                                       m::Reshape(m::Parameter(0))
                                           .WithShape(S8, {10, 12, 4, 20, 30}),
                                       m::Reshape(m::Parameter(1))
                                           .WithShape(S8, {12, 4, 44, 2, 2})))
                         .WithShape(S8, {10, 11, 4, 20, 30})),
          m::Op())));

  const ConvolutionDimensionNumbers& dims =
      vectorized_conv->convolution_dimension_numbers();
  // Size checks are fatal because the spatial dimensions are indexed below.
  ASSERT_EQ(dims.input_spatial_dimensions().size(), 2);
  ASSERT_EQ(dims.kernel_spatial_dimensions().size(), 2);
  ASSERT_EQ(dims.output_spatial_dimensions().size(), 2);

  // Every dimension after a split feature dimension is shifted up by one.
  EXPECT_EQ(dims.input_batch_dimension(), 0);
  EXPECT_EQ(dims.input_feature_dimension(), 1);
  EXPECT_EQ(dims.input_spatial_dimensions()[0], 3);
  EXPECT_EQ(dims.input_spatial_dimensions()[1], 4);

  EXPECT_EQ(dims.kernel_input_feature_dimension(), 0);
  EXPECT_EQ(dims.kernel_output_feature_dimension(), 2);
  EXPECT_EQ(dims.kernel_spatial_dimensions()[0], 3);
  EXPECT_EQ(dims.kernel_spatial_dimensions()[1], 4);

  EXPECT_EQ(dims.output_batch_dimension(), 0);
  EXPECT_EQ(dims.output_feature_dimension(), 1);
  EXPECT_EQ(dims.output_spatial_dimensions()[0], 3);
  EXPECT_EQ(dims.output_spatial_dimensions()[1], 4);
}
178
TEST_F(CudnnVectorizeConvolutionsTest, IncrementAllDnums) {
  // The feature dimension comes first in the input, filter (input features)
  // and output, so the dimension numbers that follow it are all expected to
  // be incremented once the vector dimension is inserted.
  const char* const kHloText = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[16,16,16,16] parameter(0)
    filter = s8[16,16,3,3] parameter(1)
    ROOT result = (s8[16,16,16,16], u8[0]) custom-call(input, filter),
                  window={size=2x2}, dim_labels=fb01_i01o->fb01,
                  custom_call_target="__cudnn$convForward"
  })";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(bool rewritten, Run({7, 5}, hlo_module.get()));
  EXPECT_TRUE(rewritten);

  // Print the rewritten module alongside any failure below.
  SCOPED_TRACE(hlo_module->ToString());

  const HloInstruction* vectorized_conv = nullptr;
  ASSERT_THAT(
      hlo_module->entry_computation()->root_instruction(),
      GmockMatch(m::Tuple(
          m::Reshape(m::GetTupleElement(
                         m::CustomCall(&vectorized_conv,
                                       kCudnnConvForwardCallTarget,
                                       m::Reshape(m::Parameter(0))
                                           .WithShape(S8, {4, 4, 16, 16, 16}),
                                       m::Reshape(m::Parameter(1))
                                           .WithShape(S8, {4, 4, 16, 3, 3})))
                         .WithShape(S8, {4, 4, 16, 16, 16})),
          m::Op())));

  const ConvolutionDimensionNumbers& dims =
      vectorized_conv->convolution_dimension_numbers();
  // Size checks are fatal because the spatial dimensions are indexed below.
  ASSERT_EQ(dims.input_spatial_dimensions().size(), 2);
  ASSERT_EQ(dims.kernel_spatial_dimensions().size(), 2);
  ASSERT_EQ(dims.output_spatial_dimensions().size(), 2);

  // Input: feature stays at 0, everything else moves up by one.
  EXPECT_EQ(dims.input_feature_dimension(), 0);
  EXPECT_EQ(dims.input_batch_dimension(), 2);
  EXPECT_EQ(dims.input_spatial_dimensions()[0], 3);
  EXPECT_EQ(dims.input_spatial_dimensions()[1], 4);

  // Kernel: input feature stays at 0, everything else moves up by one.
  EXPECT_EQ(dims.kernel_input_feature_dimension(), 0);
  EXPECT_EQ(dims.kernel_spatial_dimensions()[0], 2);
  EXPECT_EQ(dims.kernel_spatial_dimensions()[1], 3);
  EXPECT_EQ(dims.kernel_output_feature_dimension(), 4);

  // Output: feature stays at 0, everything else moves up by one.
  EXPECT_EQ(dims.output_feature_dimension(), 0);
  EXPECT_EQ(dims.output_batch_dimension(), 2);
  EXPECT_EQ(dims.output_spatial_dimensions()[0], 3);
  EXPECT_EQ(dims.output_spatial_dimensions()[1], 4);
}
231
TEST_F(CudnnVectorizeConvolutionsTest, FilterDnums) {
  // The input/output use bf01 while the filter uses 01io, so the filter's
  // dimension numbers must be updated independently of the input's.
  const char* const kHloText = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[1,20,9,9] parameter(0)
    filter = s8[3,3,20,32] parameter(1)
    ROOT result = (s8[1,32,9,9], u8[0]) custom-call(s8[1,20,9,9] input, s8[3,3,20,32] filter),
                  window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01,
                  custom_call_target="__cudnn$convForward"
  })";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(bool rewritten, Run({7, 5}, hlo_module.get()));
  EXPECT_TRUE(rewritten);

  // Print the rewritten module alongside any failure below.
  SCOPED_TRACE(hlo_module->ToString());

  const HloInstruction* vectorized_conv = nullptr;
  ASSERT_THAT(
      hlo_module->entry_computation()->root_instruction(),
      GmockMatch(m::Tuple(
          m::Reshape(m::GetTupleElement(
                         m::CustomCall(&vectorized_conv,
                                       kCudnnConvForwardCallTarget,
                                       m::Reshape(m::Parameter(0))
                                           .WithShape(S8, {1, 5, 4, 9, 9}),
                                       m::Reshape(m::Parameter(1))
                                           .WithShape(S8, {3, 3, 5, 4, 32})))
                         .WithShape(S8, {1, 8, 4, 9, 9})),
          m::Op())));

  const ConvolutionDimensionNumbers& dims =
      vectorized_conv->convolution_dimension_numbers();
  // Size checks are fatal because the spatial dimensions are indexed below.
  ASSERT_EQ(dims.input_spatial_dimensions().size(), 2);
  ASSERT_EQ(dims.kernel_spatial_dimensions().size(), 2);
  ASSERT_EQ(dims.output_spatial_dimensions().size(), 2);

  // Input (bf01): spatial dimensions shift from 2,3 to 3,4.
  EXPECT_EQ(dims.input_batch_dimension(), 0);
  EXPECT_EQ(dims.input_feature_dimension(), 1);
  EXPECT_EQ(dims.input_spatial_dimensions()[0], 3);
  EXPECT_EQ(dims.input_spatial_dimensions()[1], 4);

  // Kernel (01io): only the output-feature dimension shifts, from 3 to 4.
  EXPECT_EQ(dims.kernel_spatial_dimensions()[0], 0);
  EXPECT_EQ(dims.kernel_spatial_dimensions()[1], 1);
  EXPECT_EQ(dims.kernel_input_feature_dimension(), 2);
  EXPECT_EQ(dims.kernel_output_feature_dimension(), 4);

  // Output (bf01): spatial dimensions shift from 2,3 to 3,4.
  EXPECT_EQ(dims.output_batch_dimension(), 0);
  EXPECT_EQ(dims.output_feature_dimension(), 1);
  EXPECT_EQ(dims.output_spatial_dimensions()[0], 3);
  EXPECT_EQ(dims.output_spatial_dimensions()[1], 4);
}
284
TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo4) {
  // The input has 41 features, which is not a multiple of 4; the pass is
  // expected to leave this convolution unchanged.
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[10,20,30,41] parameter(0)
    filter = s8[2,2,41,44] parameter(1)
    ROOT result = (s8[10,20,30,44], u8[0]) custom-call(input, filter),
                  window={size=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })")
                    .ValueOrDie();
  // Run() constructs the pass itself, so no pass object is needed here.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));

  // Print the module alongside the failure if it was unexpectedly rewritten.
  SCOPED_TRACE(module->ToString());
  EXPECT_FALSE(changed);
}
303
304 // Don't vectorize int8_t -> int32_t into int8x4 or int8x32; this is not
305 // supported in cudnn.
TEST_F(CudnnVectorizeConvolutionsTest,NoVectorizeTo4IfOutputIsS32)306 TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo4IfOutputIsS32) {
307 auto module = ParseAndReturnVerifiedModule(R"(
308 HloModule TestModule
309
310 ENTRY TestComputation {
311 input = s8[10,20,30,41] parameter(0)
312 filter = s8[2,2,41,44] parameter(1)
313 ROOT result = (s32[10,20,30,44], u8[0]) custom-call(input, filter),
314 window={size=2x2}, dim_labels=b01f_01io->b01f,
315 custom_call_target="__cudnn$convForward"
316 })")
317 .ValueOrDie();
318 TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
319 SCOPED_TRACE(module->ToString());
320 EXPECT_FALSE(changed);
321 }
322
323 // Don't vectorize int8_t -> float into int8x4 or int8x32. Vectorizing to
324 // int8x4 *is* allowed by cudnn, but we don't do it at the moment.
TEST_F(CudnnVectorizeConvolutionsTest,NoVectorizeTo4IfOutputIsF32)325 TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo4IfOutputIsF32) {
326 auto module = ParseAndReturnVerifiedModule(R"(
327 HloModule TestModule
328
329 ENTRY TestComputation {
330 input = s8[10,20,30,41] parameter(0)
331 filter = s8[2,2,41,44] parameter(1)
332 ROOT result = (f32[10,20,30,44], u8[0]) custom-call(input, filter),
333 window={size=2x2}, dim_labels=b01f_01io->b01f,
334 custom_call_target="__cudnn$convForward"
335 })")
336 .ValueOrDie();
337 TF_ASSERT_OK_AND_ASSIGN(bool changed, Run({7, 5}, module.get()));
338 SCOPED_TRACE(module->ToString());
339 EXPECT_FALSE(changed);
340 }
341
TEST_F(CudnnVectorizeConvolutionsTest, VectorizeTo32) {
  // 64 input and 128 output features, run with compute capability {7, 5}:
  // the expected shapes below use a vector size of 32.
  const char* const kHloText = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[10,20,30,64] parameter(0)
    filter = s8[2,2,64,128] parameter(1)
    ROOT result = (s8[10,20,30,128], u8[0]) custom-call(input, filter),
                  window={size=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(bool rewritten, Run({7, 5}, hlo_module.get()));
  EXPECT_TRUE(rewritten);

  // Print the rewritten module alongside any failure below.
  SCOPED_TRACE(hlo_module->ToString());

  // `vectorized_conv` is captured by the matcher but not inspected further.
  const HloInstruction* vectorized_conv = nullptr;
  ASSERT_THAT(
      hlo_module->entry_computation()->root_instruction(),
      GmockMatch(m::Tuple(
          m::Reshape(m::GetTupleElement(
                         m::CustomCall(&vectorized_conv,
                                       kCudnnConvForwardCallTarget,
                                       m::Reshape(m::Parameter(0))
                                           .WithShape(S8, {10, 20, 30, 2, 32}),
                                       m::Reshape(m::Parameter(1))
                                           .WithShape(S8, {2, 2, 2, 32, 128})))
                         .WithShape(S8, {10, 20, 30, 4, 32})),
          m::Op())));
}
373
TEST_F(CudnnVectorizeConvolutionsTest, BiasAndSideInput) {
  // Four-operand convolution (input, filter, bias, side input). The
  // expectations below require the bias to be passed through untouched and
  // the side input to be reshaped like the input.
  const char* const kHloText = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[10,20,30,64] parameter(0)
    filter = s8[2,2,64,128] parameter(1)
    bias = f32[10] parameter(2)
    side_input = s8[10,20,30,64] parameter(3)

    ROOT result = (s8[10,20,30,128], u8[0]) custom-call(input, filter, bias, side_input),
                  window={size=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(bool rewritten, Run({7, 5}, hlo_module.get()));
  EXPECT_TRUE(rewritten);

  // Print the rewritten module alongside any failure below.
  SCOPED_TRACE(hlo_module->ToString());

  // `vectorized_conv` is captured by the matcher but not inspected further.
  const HloInstruction* vectorized_conv = nullptr;
  ASSERT_THAT(
      hlo_module->entry_computation()->root_instruction(),
      GmockMatch(m::Tuple(
          m::Reshape(m::GetTupleElement(
                         m::CustomCall(&vectorized_conv,
                                       kCudnnConvForwardCallTarget,
                                       m::Reshape(m::Parameter(0))
                                           .WithShape(S8, {10, 20, 30, 2, 32}),
                                       m::Reshape(m::Parameter(1))
                                           .WithShape(S8, {2, 2, 2, 32, 128}),
                                       m::Parameter(2),
                                       m::Reshape(m::Parameter(3))
                                           .WithShape(S8, {10, 20, 30, 2, 32})))
                         .WithShape(S8, {10, 20, 30, 4, 32})),
          m::Op())));
}
411
TEST_F(CudnnVectorizeConvolutionsTest, NoVectorizeTo32) {
  // Same module as VectorizeTo32, but run with compute capability {7, 0}.
  // The module is still rewritten; the expected shapes show a vector size of
  // 4 rather than 32.
  const char* const kHloText = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[10,20,30,64] parameter(0)
    filter = s8[2,2,64,128] parameter(1)
    ROOT result = (s8[10,20,30,128], u8[0]) custom-call(input, filter),
                  window={size=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(bool rewritten, Run({7, 0}, hlo_module.get()));
  EXPECT_TRUE(rewritten);

  // Print the rewritten module alongside any failure below.
  SCOPED_TRACE(hlo_module->ToString());

  // `vectorized_conv` is captured by the matcher but not inspected further.
  const HloInstruction* vectorized_conv = nullptr;
  ASSERT_THAT(
      hlo_module->entry_computation()->root_instruction(),
      GmockMatch(m::Tuple(
          m::Reshape(m::GetTupleElement(
                         m::CustomCall(&vectorized_conv,
                                       kCudnnConvForwardCallTarget,
                                       m::Reshape(m::Parameter(0))
                                           .WithShape(S8, {10, 20, 30, 16, 4}),
                                       m::Reshape(m::Parameter(1))
                                           .WithShape(S8, {2, 2, 16, 4, 128})))
                         .WithShape(S8, {10, 20, 30, 32, 4})),
          m::Op())));
}
443
TEST_F(CudnnVectorizeConvolutionsTest, Vectorize4To32) {
  // The operands are 5-D with a trailing dimension of size 4. The
  // expectations below require each of input, filter and side input to go
  // through reshape -> transpose -> reshape, ending with a trailing dimension
  // of size 32.
  const char* const kHloText = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[10,20,30,16,4] parameter(0)
    filter = s8[3,5,16,192,4] parameter(1)
    bias = f32[10] parameter(2)
    side_input = s8[10,20,30,16,4] parameter(3)
    ROOT result = (s8[10,20,30,48,4], u8[0]) custom-call(input, filter, bias, side_input),
                  window={size=3x5}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward",
                  backend_config="{foo: 42}"
  })";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(bool rewritten, Run({7, 5}, hlo_module.get()));
  EXPECT_TRUE(rewritten);

  // Print the rewritten module alongside any failure below.
  SCOPED_TRACE(hlo_module->ToString());

  // Matches element 0 of the re-vectorized convolution's result tuple and
  // captures the custom-call in `revectorized_conv`.
  const HloInstruction* revectorized_conv = nullptr;
  auto conv_output_matcher =
      m::GetTupleElement(
          m::CustomCall(
              &revectorized_conv, kCudnnConvForwardCallTarget,
              m::Reshape(m::Transpose(m::Reshape(m::Parameter(0))
                                          .WithShape(S8, {10, 20, 30, 2, 8, 4}))
                             .WithShape(S8, {10, 20, 30, 2, 8, 4}))
                  .WithShape(S8, {10, 20, 30, 2, 32}),
              m::Reshape(m::Transpose(m::Reshape(m::Parameter(1))
                                          .WithShape(S8, {3, 5, 2, 8, 192, 4}))
                             .WithShape(S8, {3, 5, 2, 192, 8, 4}))
                  .WithShape(S8, {3, 5, 2, 192, 32}),
              m::Parameter(2),
              m::Reshape(m::Transpose(m::Reshape(m::Parameter(3))
                                          .WithShape(S8, {10, 20, 30, 2, 8, 4}))
                             .WithShape(S8, {10, 20, 30, 2, 8, 4}))
                  .WithShape(S8, {10, 20, 30, 2, 32})))
          .WithShape(S8, {10, 20, 30, 6, 32});

  // The convolution output goes through reshape -> transpose -> reshape to
  // recover the original result shape.
  ASSERT_THAT(
      hlo_module->entry_computation()->root_instruction(),
      GmockMatch(m::Tuple(
          m::Reshape(m::Transpose(m::Reshape(conv_output_matcher)
                                      .WithShape(S8, {10, 20, 30, 6, 8, 4}))
                         .WithShape(S8, {10, 20, 30, 6, 8, 4}))
              .WithShape(S8, {10, 20, 30, 48, 4}),
          m::Op())));

  // The backend config string must be carried over to the new convolution.
  EXPECT_EQ(revectorized_conv->raw_backend_config_string(), "{foo: 42}");

  const ConvolutionDimensionNumbers& dims =
      revectorized_conv->convolution_dimension_numbers();
  // Size checks are fatal because the spatial dimensions are indexed below.
  ASSERT_EQ(dims.input_spatial_dimensions().size(), 2);
  ASSERT_EQ(dims.kernel_spatial_dimensions().size(), 2);
  ASSERT_EQ(dims.output_spatial_dimensions().size(), 2);

  // The dimension numbers are expected to match dim_labels b01f_01io->b01f.
  EXPECT_EQ(dims.input_batch_dimension(), 0);
  EXPECT_EQ(dims.input_spatial_dimensions()[0], 1);
  EXPECT_EQ(dims.input_spatial_dimensions()[1], 2);
  EXPECT_EQ(dims.input_feature_dimension(), 3);

  EXPECT_EQ(dims.kernel_spatial_dimensions()[0], 0);
  EXPECT_EQ(dims.kernel_spatial_dimensions()[1], 1);
  EXPECT_EQ(dims.kernel_input_feature_dimension(), 2);
  EXPECT_EQ(dims.kernel_output_feature_dimension(), 3);

  EXPECT_EQ(dims.output_batch_dimension(), 0);
  EXPECT_EQ(dims.output_spatial_dimensions()[0], 1);
  EXPECT_EQ(dims.output_spatial_dimensions()[1], 2);
  EXPECT_EQ(dims.output_feature_dimension(), 3);
}
514
TEST_F(CudnnVectorizeConvolutionsTest, Vectorize4To32NCHW) {
  // bf01 variant of Vectorize4To32: the operands are 5-D with a trailing
  // dimension of size 4 and the feature dimension at index 1 (input/output)
  // or 0 (filter input features).
  const char* const kHloText = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[10,16,20,30,4] parameter(0)
    filter = s8[16,128,2,2,4] parameter(1)
    bias = f32[10] parameter(2)
    side_input = s8[10,16,20,30,4] parameter(3)
    ROOT result = (s8[10,32,20,30,4], u8[0]) custom-call(input, filter, bias, side_input),
                  window={size=2x2}, dim_labels=bf01_io01->bf01,
                  custom_call_target="__cudnn$convForward"
  })";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(bool rewritten, Run({7, 5}, hlo_module.get()));
  EXPECT_TRUE(rewritten);

  // Print the rewritten module alongside any failure below.
  SCOPED_TRACE(hlo_module->ToString());

  // Matches element 0 of the re-vectorized convolution's result tuple and
  // captures the custom-call in `revectorized_conv`.
  const HloInstruction* revectorized_conv = nullptr;
  auto conv_output_matcher =
      m::GetTupleElement(
          m::CustomCall(
              &revectorized_conv, kCudnnConvForwardCallTarget,
              m::Reshape(m::Transpose(m::Reshape(m::Parameter(0))
                                          .WithShape(S8, {10, 2, 8, 20, 30, 4}))
                             .WithShape(S8, {10, 2, 20, 30, 8, 4}))
                  .WithShape(S8, {10, 2, 20, 30, 32}),
              m::Reshape(m::Transpose(m::Reshape(m::Parameter(1))
                                          .WithShape(S8, {2, 8, 128, 2, 2, 4}))
                             .WithShape(S8, {2, 128, 2, 2, 8, 4}))
                  .WithShape(S8, {2, 128, 2, 2, 32}),
              m::Parameter(2),
              m::Reshape(m::Transpose(m::Reshape(m::Parameter(3))
                                          .WithShape(S8, {10, 2, 8, 20, 30, 4}))
                             .WithShape(S8, {10, 2, 20, 30, 8, 4}))
                  .WithShape(S8, {10, 2, 20, 30, 32})))
          .WithShape(S8, {10, 4, 20, 30, 32});

  // The convolution output goes through reshape -> transpose -> reshape to
  // recover the original result shape.
  ASSERT_THAT(
      hlo_module->entry_computation()->root_instruction(),
      GmockMatch(m::Tuple(
          m::Reshape(m::Transpose(m::Reshape(conv_output_matcher)
                                      .WithShape(S8, {10, 4, 20, 30, 8, 4}))
                         .WithShape(S8, {10, 4, 8, 20, 30, 4}))
              .WithShape(S8, {10, 32, 20, 30, 4}),
          m::Op())));

  const ConvolutionDimensionNumbers& dims =
      revectorized_conv->convolution_dimension_numbers();
  // Size checks are fatal because the spatial dimensions are indexed below.
  ASSERT_EQ(dims.input_spatial_dimensions().size(), 2);
  ASSERT_EQ(dims.kernel_spatial_dimensions().size(), 2);
  ASSERT_EQ(dims.output_spatial_dimensions().size(), 2);

  // The dimension numbers are expected to match dim_labels bf01_io01->bf01.
  EXPECT_EQ(dims.input_batch_dimension(), 0);
  EXPECT_EQ(dims.input_feature_dimension(), 1);
  EXPECT_EQ(dims.input_spatial_dimensions()[0], 2);
  EXPECT_EQ(dims.input_spatial_dimensions()[1], 3);

  EXPECT_EQ(dims.kernel_input_feature_dimension(), 0);
  EXPECT_EQ(dims.kernel_output_feature_dimension(), 1);
  EXPECT_EQ(dims.kernel_spatial_dimensions()[0], 2);
  EXPECT_EQ(dims.kernel_spatial_dimensions()[1], 3);

  EXPECT_EQ(dims.output_batch_dimension(), 0);
  EXPECT_EQ(dims.output_feature_dimension(), 1);
  EXPECT_EQ(dims.output_spatial_dimensions()[0], 2);
  EXPECT_EQ(dims.output_spatial_dimensions()[1], 3);
}
582
TEST_F(CudnnVectorizeConvolutionsTest, Vectorize4To32VectorDimFirst) {
  // Here the extra size-4 dimension (the `?` in dim_labels) is the leading
  // dimension of each operand instead of the trailing one. The expectations
  // below require it to grow to 32 while staying in the leading position.
  const char* const kHloText = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[4,10,20,30,16] parameter(0)
    filter = s8[4,3,5,16,192] parameter(1)
    bias = f32[10] parameter(2)
    side_input = s8[4,10,20,30,16] parameter(3)
    ROOT result = (s8[4,10,20,30,48], u8[0]) custom-call(input, filter, bias, side_input),
                  window={size=3x5}, dim_labels=?b01f_?01io->?b01f,
                  custom_call_target="__cudnn$convForward"
  })";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(bool rewritten, Run({7, 5}, hlo_module.get()));
  EXPECT_TRUE(rewritten);

  // Print the rewritten module alongside any failure below.
  SCOPED_TRACE(hlo_module->ToString());

  // Matches element 0 of the re-vectorized convolution's result tuple and
  // captures the custom-call in `revectorized_conv`.
  const HloInstruction* revectorized_conv = nullptr;
  auto conv_output_matcher =
      m::GetTupleElement(
          m::CustomCall(
              &revectorized_conv, kCudnnConvForwardCallTarget,
              m::Reshape(m::Transpose(m::Reshape(m::Parameter(0))
                                          .WithShape(S8, {4, 10, 20, 30, 2, 8}))
                             .WithShape(S8, {8, 4, 10, 20, 30, 2}))
                  .WithShape(S8, {32, 10, 20, 30, 2}),
              m::Reshape(m::Transpose(m::Reshape(m::Parameter(1))
                                          .WithShape(S8, {4, 3, 5, 2, 8, 192}))
                             .WithShape(S8, {8, 4, 3, 5, 2, 192}))
                  .WithShape(S8, {32, 3, 5, 2, 192}),
              m::Parameter(2),
              m::Reshape(m::Transpose(m::Reshape(m::Parameter(3))
                                          .WithShape(S8, {4, 10, 20, 30, 2, 8}))
                             .WithShape(S8, {8, 4, 10, 20, 30, 2}))
                  .WithShape(S8, {32, 10, 20, 30, 2})))
          .WithShape(S8, {32, 10, 20, 30, 6});

  // The convolution output goes through reshape -> transpose -> reshape to
  // recover the original result shape.
  ASSERT_THAT(
      hlo_module->entry_computation()->root_instruction(),
      GmockMatch(m::Tuple(
          m::Reshape(m::Transpose(m::Reshape(conv_output_matcher)
                                      .WithShape(S8, {8, 4, 10, 20, 30, 6}))
                         .WithShape(S8, {4, 10, 20, 30, 6, 8}))
              .WithShape(S8, {4, 10, 20, 30, 48}),
          m::Op())));

  const ConvolutionDimensionNumbers& dims =
      revectorized_conv->convolution_dimension_numbers();
  // Size checks are fatal because the spatial dimensions are indexed below.
  ASSERT_EQ(dims.input_spatial_dimensions().size(), 2);
  ASSERT_EQ(dims.kernel_spatial_dimensions().size(), 2);
  ASSERT_EQ(dims.output_spatial_dimensions().size(), 2);

  // The dimension numbers are expected to match dim_labels
  // ?b01f_?01io->?b01f, with every labelled dimension offset by the leading
  // `?` dimension.
  EXPECT_EQ(dims.input_batch_dimension(), 1);
  EXPECT_EQ(dims.input_spatial_dimensions()[0], 2);
  EXPECT_EQ(dims.input_spatial_dimensions()[1], 3);
  EXPECT_EQ(dims.input_feature_dimension(), 4);

  EXPECT_EQ(dims.kernel_spatial_dimensions()[0], 1);
  EXPECT_EQ(dims.kernel_spatial_dimensions()[1], 2);
  EXPECT_EQ(dims.kernel_input_feature_dimension(), 3);
  EXPECT_EQ(dims.kernel_output_feature_dimension(), 4);

  EXPECT_EQ(dims.output_batch_dimension(), 1);
  EXPECT_EQ(dims.output_spatial_dimensions()[0], 2);
  EXPECT_EQ(dims.output_spatial_dimensions()[1], 3);
  EXPECT_EQ(dims.output_feature_dimension(), 4);
}
650
TEST_F(CudnnVectorizeConvolutionsTest, NoVectorize4To32) {
  // 5-D operands with a trailing dimension of size 4, run with compute
  // capability {7, 0}: the pass is expected to leave the module unchanged.
  const char* const kHloText = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[10,20,30,16,4] parameter(0)
    filter = s8[2,2,16,128,4] parameter(1)
    bias = f32[10] parameter(2)
    side_input = s8[10,20,30,16,4] parameter(3)
    ROOT result = (s8[10,20,30,32,4], u8[0]) custom-call(input, filter, bias, side_input),
                  window={size=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(bool rewritten, Run({7, 0}, hlo_module.get()));
  EXPECT_FALSE(rewritten);
}
668
669 } // namespace
670 } // namespace gpu
671 } // namespace xla
672