/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h"

#include <string>

#include "absl/strings/str_replace.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/convert_mover.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace xla {
namespace gpu {
namespace {

// TODO(b/210165681): The tests in this file are fragile to HLO op names.

namespace m = match;

using ::testing::HasSubstr;
using ::testing::Not;

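// Fixture for tests that run individual rewriter passes by hand on parsed
// HLO and pattern-match the result; nothing is compiled or executed on a GPU.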
class CudnnFusedConvRewriterHloTest : public HloTestBase {
 public:
  CudnnFusedConvRewriterHloTest()
      : HloTestBase(/*verifier_layout_sensitive=*/false,
                    /*allow_mixed_precision_in_hlo_verifier=*/false,
                    /*instruction_can_change_layout_func=*/{}) {}
};

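// Fixture for tests that compile HLO through the full GPU pass pipeline (and
// usually execute it), checking the optimized HLO for cudnn custom-calls.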
class CudnnFusedConvRewriterTest : public GpuCodegenTest {
 protected:
  std::string GetOptimizedHlo(absl::string_view hlo_string) {
    // cudnn_vectorize_convolutions transforms convolutions, making it hard to
    // match them here in this test. What's worse, the transforms it performs
    // depend on the GPU that's available! So just disable it for this
    // function that gets the optimized HLO. When we actually run the module
    // we'll still have this pass enabled.
    HloModuleConfig config = GetModuleConfigForTest();
    DebugOptions debug_opts = config.debug_options();
    debug_opts.add_xla_disable_hlo_passes("cudnn_vectorize_convolutions");
    config.set_debug_options(debug_opts);

    auto result = backend().compiler()->RunHloPasses(
        ParseAndReturnVerifiedModule(hlo_string, config).value(),
        backend().default_stream_executor(), backend().memory_allocator());
    if (!result.status().ok()) {
      TF_EXPECT_OK(result.status())
          << "HLO compilation failed: " << result.status();
      return "";
    }
    HloPrintOptions print_opts;
    print_opts.set_print_operand_shape(false);
    return (*result)->ToString(print_opts);
  }

  void TestMatchWithAllTypes(absl::string_view hlo_string) {
    for (absl::string_view type : {"f16", "f32", "f64"}) {
      const std::string hlo_with_new_type =
          absl::StrReplaceAll(hlo_string, {{"TYPE", type}});
      std::string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type);
      EXPECT_THAT(optimized_hlo_string,
                  Not(HasSubstr(kCudnnConvForwardCallTarget)))
          << optimized_hlo_string;
      EXPECT_THAT(optimized_hlo_string,
                  HasSubstr(kCudnnConvBiasActivationForwardCallTarget));
      EXPECT_TRUE(RunAndCompare(hlo_with_new_type, ErrorSpec{0.01}))
          << optimized_hlo_string;
    }
  }

  void TestClamp(absl::string_view pre_hlo_string,
                 absl::string_view post_hlo_string) {
    std::string optimized_hlo_string = GetOptimizedHlo(pre_hlo_string);
    EXPECT_THAT(optimized_hlo_string, Not(HasSubstr("Convert")));
    EXPECT_THAT(optimized_hlo_string, HasSubstr("__cudnn$conv"));
    EXPECT_TRUE(RunAndCompare(pre_hlo_string, ErrorSpec{0.01}))
        << pre_hlo_string;

    StatusOr<bool> filecheck_result =
        RunFileCheck(optimized_hlo_string, post_hlo_string);
    ASSERT_TRUE(filecheck_result.ok()) << filecheck_result.status();
    EXPECT_TRUE(*filecheck_result);
  }

  void TestNotMatchWithAllTypes(absl::string_view hlo_string) {
    for (absl::string_view type : {"f16", "f32", "f64"}) {
      const std::string hlo_with_new_type =
          absl::StrReplaceAll(hlo_string, {{"TYPE", type}});
      std::string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type);
      SCOPED_TRACE(optimized_hlo_string);
      EXPECT_THAT(optimized_hlo_string, HasSubstr(kCudnnConvForwardCallTarget));
      EXPECT_THAT(optimized_hlo_string,
                  Not(HasSubstr(kCudnnConvBiasActivationForwardCallTarget)));
    }
  }
};

TEST_F(CudnnFusedConvRewriterTest, TestConvOnly) {
  // max(0, conv(x, w));
  TestMatchWithAllTypes(R"(
    HloModule Test

    ENTRY Test {
      zero = TYPE[] constant(0)
      zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={}

      input = TYPE[1,17,9,9] parameter(0)
      filter = TYPE[3,3,17,32] parameter(1)

      conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
      ROOT relu = TYPE[1,32,9,9] maximum(zeros, conv)
    })");
}

TEST_F(CudnnFusedConvRewriterTest, TestBias) {
  // max(0, conv(x, w) + bias);
  TestMatchWithAllTypes(R"(
    HloModule Test

    ENTRY Test {
      zero = TYPE[] constant(0)
      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}

      input = TYPE[1,3,3,64] parameter(0)
      filter = TYPE[3,3,64,64] parameter(1)
      bias = TYPE[64] parameter(2)

      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
      broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
      add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
    })");
}

TEST_F(CudnnFusedConvRewriterTest, TestSideInputOnly) {
  // max(0, conv(x, w) + side_input);
  TestMatchWithAllTypes(R"(
    HloModule Test

    ENTRY Test {
      zero = TYPE[] constant(0)
      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}

      input = TYPE[1,3,3,64] parameter(0)
      filter = TYPE[3,3,64,64] parameter(1)
      side_input = TYPE[1,3,3,64] parameter(2)

      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
      add1 = TYPE[1,3,3,64] add(conv, side_input)
      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
    })");
}

TEST_F(CudnnFusedConvRewriterTest, TestBiasAndSideInput) {
  // max(0, conv(x, w) + side_input + bias);
  TestMatchWithAllTypes(R"(
    HloModule Test

    ENTRY Test {
      zero = TYPE[] constant(0)
      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}

      input = TYPE[1,3,3,64] parameter(0)
      filter = TYPE[3,3,64,64] parameter(1)
      side_input = TYPE[1,3,3,64] parameter(2)
      bias = TYPE[64] parameter(3)

      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
      broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
      add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
      add2 = TYPE[1,3,3,64] add(add1, side_input)
      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
    })");
}

TEST_F(CudnnFusedConvRewriterTest, TestScaledConv) {
  // max(0, 0.999994934 * conv(x, w));
  TestMatchWithAllTypes(R"(
    HloModule Test

    ENTRY Test {
      zero = TYPE[] constant(0)
      zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={}
      alpha_conv_scalar = TYPE[] constant(0.999994934)

      input = TYPE[1,17,9,9] parameter(0)
      filter = TYPE[3,3,17,32] parameter(1)

      conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
      alpha_conv = TYPE[1,32,9,9] broadcast(alpha_conv_scalar), dimensions={}
      scaled_conv = TYPE[1,32,9,9] multiply(conv, alpha_conv)
      ROOT relu = TYPE[1,32,9,9] maximum(zeros, scaled_conv)
    })");
}

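// Check that the rewriter doesn't crash (or miscompare) when the maximum
// compares against inf rather than 0.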
TEST_F(CudnnFusedConvRewriterTest, TestNoCrashOnInf) {
  EXPECT_TRUE(RunAndCompare(R"(
    HloModule Test

    ENTRY Test {
      zero = f32[] constant(inf)
      zeros = f32[1,32,9,9] broadcast(zero), dimensions={}
      alpha_conv_scalar = f32[] constant(0.999994934)

      input = f32[1,17,9,9] parameter(0)
      filter = f32[3,3,17,32] parameter(1)

      conv = f32[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
      alpha_conv = f32[1,32,9,9] broadcast(alpha_conv_scalar), dimensions={}
      scaled_conv = f32[1,32,9,9] multiply(conv, alpha_conv)
      ROOT relu = f32[1,32,9,9] maximum(zeros, scaled_conv)
    })",
                            ErrorSpec{0.01}));
}

TEST_F(CudnnFusedConvRewriterTest, TestScaledConvAndSideInput) {
  // max(0, conv(x, w) + 0.899994934 * side_input);
  TestMatchWithAllTypes(R"(
    HloModule Test

    ENTRY Test {
      zero = TYPE[] constant(0)
      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
      alpha_side_input_scalar = TYPE[] constant(0.899994934)
      alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}

      input = TYPE[1,3,3,64] parameter(0)
      filter = TYPE[3,3,64,64] parameter(1)
      side_input = TYPE[1,3,3,64] parameter(2)

      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
      scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
      add1 = TYPE[1,3,3,64] add(conv, scaled_side_input)
      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
    })");
}

TEST_F(CudnnFusedConvRewriterTest, TestScaledConvAndScaledSideInput) {
  // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input);
  TestMatchWithAllTypes(R"(
    HloModule Test

    ENTRY Test {
      zero = TYPE[] constant(0)
      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
      alpha_conv_scalar = TYPE[] constant(0.999994934)
      alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
      alpha_side_input_scalar = TYPE[] constant(0.899994934)
      alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}

      input = TYPE[1,3,3,64] parameter(0)
      filter = TYPE[3,3,64,64] parameter(1)
      side_input = TYPE[1,3,3,64] parameter(2)

      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
      scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv)
      scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
      add1 = TYPE[1,3,3,64] add(scaled_conv, scaled_side_input)
      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
    })");
}

TEST_F(CudnnFusedConvRewriterTest, TestScaledConvAndScaledSideInputWithBias) {
  // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input + bias);
  TestMatchWithAllTypes(R"(
    HloModule Test

    ENTRY Test {
      zero = TYPE[] constant(0)
      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
      alpha_conv_scalar = TYPE[] constant(0.999994934)
      alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
      alpha_side_input_scalar = TYPE[] constant(0.899994934)
      alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}

      input = TYPE[1,3,3,64] parameter(0)
      filter = TYPE[3,3,64,64] parameter(1)
      side_input = TYPE[1,3,3,64] parameter(2)
      bias = TYPE[64] parameter(3)

      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
      scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv)
      scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
      broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
      add1 = TYPE[1,3,3,64] add(scaled_conv, broadcasted_bias)
      add2 = TYPE[1,3,3,64] add(add1, scaled_side_input)
      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
    })");
}

TEST_F(CudnnFusedConvRewriterTest, TestMatchMaxZeroOnly) {
  // max(0.1, conv(x, w)) shouldn't match.
  TestNotMatchWithAllTypes(R"(
    HloModule Test

    ENTRY Test {
      point_one = TYPE[] constant(0.1)
      point_ones = TYPE[1,32,9,9] broadcast(point_one), dimensions={}

      input = TYPE[1,17,9,9] parameter(0)
      filter = TYPE[3,3,17,32] parameter(1)

      conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
      ROOT relu = TYPE[1,32,9,9] maximum(point_ones, conv)
    })");
}

TEST_F(CudnnFusedConvRewriterTest, PreservesMetadata) {
  const char* kHloString = R"(
    HloModule Test

    ENTRY Test {
      zero = f32[] constant(0)
      zeros = f32[1,32,9,9] broadcast(zero), dimensions={}

      input = f32[1,17,9,9] parameter(0)
      filter = f32[3,3,17,32] parameter(1)

      conv = f32[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1, metadata={op_type="foo" op_name="bar"}
      ROOT relu = f32[1,32,9,9] maximum(zeros, conv)
    })";

  const std::string optimized_hlo_string =
      backend()
          .compiler()
          ->RunHloPasses(
              ParseAndReturnVerifiedModule(kHloString, GetModuleConfigForTest())
                  .value(),
              backend().default_stream_executor(), backend().memory_allocator())
          .value()
          ->ToString();
  EXPECT_THAT(optimized_hlo_string,
              ::testing::ContainsRegex(
                  R"(custom-call.*metadata=\{op_type="foo" op_name="bar"\})"));
}

TEST_F(CudnnFusedConvRewriterTest, TestPreservesFeatureGroupCount) {
  // The convolution below would crash if feature_group_count is not preserved.
  const char* kHloString = R"(
    HloModule jaxpr_computation__6.19

    primitive_computation__1.4 {
      parameter.5 = f32[] parameter(0)
      parameter.6 = f32[] parameter(1)
      ROOT add.7 = f32[] add(parameter.5, parameter.6)
    }

    ENTRY jaxpr_computation__7.8 {
      parameter.11 = f32[2,64,64,53]{3,2,1,0} parameter(1)
      parameter.10 = f32[3,3,1,53]{3,2,1,0} parameter(0)
      convolution.12 = f32[2,64,64,53]{3,2,1,0} convolution(parameter.11, parameter.10), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=53
      constant.13 = f32[] constant(0)
      broadcast.14 = f32[2,64,64,53]{3,2,1,0} broadcast(constant.13), dimensions={}
      maximum.15 = f32[2,64,64,53]{3,2,1,0} maximum(convolution.12, broadcast.14)
      ROOT reduce.17 = f32[] reduce(maximum.15, constant.13), dimensions={0,1,2,3}, to_apply=primitive_computation__1.4
    }
  )";
  EXPECT_TRUE(RunAndCompare(kHloString, ErrorSpec{0.01}));
}

TEST_F(CudnnFusedConvRewriterTest, TestConvInt8ToInt8) {
  // max(0, clamp(conv(x, w))); for int8_t
  TestClamp(
      // pre_hlo
      R"(
    HloModule Test

    ENTRY Test {
      zero = s8[] constant(0)
      zeros = s8[1,32,9,9] broadcast(zero), dimensions={}

      input = s8[1,17,9,9] parameter(0)
      filter = s8[3,3,17,32] parameter(1)

      inputs32 = s32[1,17,9,9] convert(input)
      filters32 = s32[3,3,17,32] convert(filter)

      conv = s32[1,32,9,9] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1

      lower = s32[] constant(-128)
      lowers = s32[1,32,9,9] broadcast(lower), dimensions={}
      upper = s32[] constant(127)
      uppers = s32[1,32,9,9] broadcast(upper), dimensions={}

      clamp = s32[1,32,9,9] clamp(lowers, conv, uppers)

      ROOT convert = s8[1,32,9,9] convert(clamp)
    })",
      // post_hlo
      R"(
; CHECK-LABEL: ENTRY %Test (input: s8[1,17,9,9], filter: s8[3,3,17,32]) -> s8[1,32,9,9] {
; CHECK: %cudnn-conv{{(\.[0-9])?}} = (s8[1,32,9,9]{1,3,2,0}, u8[{{[0-9]*}}]{0}) custom-call(%fusion{{(\.[0-9])?}}, %fusion{{(\.[0-9])?}}), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convForward", backend_config=
      )");
}

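// conv(s8, s8) -> f32 should be rewritten to a forward conv custom-call that
// produces f32 directly.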
TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloat) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      input = s8[1,17,9,9] parameter(0)
      filter = s8[3,3,17,32] parameter(1)

      inputs32 = s32[1,17,9,9] convert(input)
      filters32 = s32[3,3,17,32] convert(filter)

      conv = s32[1,32,9,9] convolution(inputs32, filters32),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01

      ROOT convert = f32[1,32,9,9] convert(conv)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::GetTupleElement(
                             m::CustomCall(kCudnnConvForwardCallTarget), 0)
                             .WithShape(F32, {1, 32, 9, 9})));
}

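// clamp(conv(s8, s8) + bias + side_input) -> s8 should fuse into a single
// bias-activation custom-call with s8 output.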
TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToInt8BiasSideInput) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      input = s32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
      filter = s32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
      side_input = f32[1,32,9,9] convert(s8[1,32,9,9] parameter(3))

      conv = s32[1,32,9,9] convolution(input, filter),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      conv_f32 = f32[1,32,9,9] convert(conv)
      ROOT root = s8[1,32,9,9] convert(clamp(f32[1,32,9,9] broadcast(f32[] constant(-128)),
                                             add(add(conv_f32, bias), side_input),
                                             f32[1,32,9,9] broadcast(f32[] constant(127))))
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  // Simplify new `convert`s that may be added to the graph.
  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());

  SCOPED_TRACE(m->ToString());
  EXPECT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(m::GetTupleElement(
                     m::CustomCall(kCudnnConvBiasActivationForwardCallTarget,
                                   m::Parameter(0), m::Parameter(1),
                                   m::Parameter(2), m::Parameter(3)),
                     0)
                     .WithShape(S8, {1, 32, 9, 9})));
}

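// max(0, convert<s8>(clamp(conv(s8, s8)))) should fuse the relu into the
// custom-call's activation mode.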
TEST_F(CudnnFusedConvRewriterHloTest, TestReluAfterConvert) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      input = s32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
      filter = s32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))

      conv = s32[1,32,9,9] convolution(input, filter),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      conv_s8 = s8[1,32,9,9] convert(clamp(s32[1,32,9,9] broadcast(s32[] constant(-128)),
                                           conv,
                                           s32[1,32,9,9] broadcast(s32[] constant(127))))
      zeros = s8[1,32,9,9] broadcast(s8[] constant(0)), dimensions={}
      ROOT root = s8[1,32,9,9] maximum(conv_s8, zeros)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  // Simplify new `convert`s that may be added to the graph.
  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::GetTupleElement(
              m::CustomCall(
                  &conv, kCudnnConvBiasActivationForwardCallTarget,
                  m::Parameter(0),  //
                  m::Parameter(1),  //
                  m::Broadcast(
                      m::ConstantEffectiveScalar(0).WithElementType(F32))),
              0)
              .WithShape(S8, {1, 32, 9, 9})));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
}

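// conv(s8, s8) + f32 bias + f32 side input should fuse, keeping f32 output.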
TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloatBiasSideInput) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      input = s8[1,17,9,9] parameter(0)
      filter = s8[3,3,17,32] parameter(1)
      bias = f32[32] parameter(2)
      bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
      side_input_f32 = f32[1,32,9,9] parameter(3)

      inputs32 = s32[1,17,9,9] convert(input)
      filters32 = s32[3,3,17,32] convert(filter)

      conv = s32[1,32,9,9] convolution(inputs32, filters32),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      conv_f32 = f32[1,32,9,9] convert(conv)
      sum1 = add(conv_f32, bias_broadcast)
      ROOT sum2 = add(sum1, side_input_f32)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  // Simplify new `convert`s that may be added to the graph.
  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());

  SCOPED_TRACE(m->ToString());
  EXPECT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(m::GetTupleElement(
                     m::CustomCall(kCudnnConvBiasActivationForwardCallTarget,
                                   m::Parameter(0), m::Parameter(1),
                                   m::Parameter(2), m::Parameter(3)),
                     0)
                     .WithShape(F32, {1, 32, 9, 9})));
}

// The ReshapeMover pass changes
//   reshape(side_input) * alpha -->
//   reshape(side_input * alpha).
// Make sure we can pattern-match this.
TEST_F(CudnnFusedConvRewriterHloTest, Int8SideInputWithScaleAndReshape) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      input = s32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
      filter = s32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
      side_input_scale = f32[2592] broadcast(f32[] constant(0.25)), dimensions={}
      side_input = f32[1,32,9,9] reshape(multiply(f32[2592] convert(s8[2592] parameter(3)), side_input_scale))

      conv = s32[1,32,9,9] convolution(input, filter),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      ROOT root = s8[1,32,9,9] convert(clamp(f32[1,32,9,9] broadcast(f32[] constant(-128)),
                                             add(add(f32[1,32,9,9] convert(conv), bias), side_input),
                                             f32[1,32,9,9] broadcast(f32[] constant(127))))
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  // Simplify new `convert`s that may be added to the graph.
  HloPassFix<HloPassPipeline> simplify("simplify");
  simplify.AddPass<AlgebraicSimplifier>(AlgebraicSimplifierOptions{});
  simplify.AddPass<ReshapeMover>();
  simplify.AddPass<ConvertMover>();
  TF_ASSERT_OK(RunHloPass(&simplify, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv = nullptr;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::GetTupleElement(
              m::CustomCall(
                  &conv, kCudnnConvBiasActivationForwardCallTarget,
                  m::Parameter(0),  //
                  m::Parameter(1),  //
                  m::Parameter(2),  //
                  m::Reshape(m::Parameter(3)).WithShape(S8, {1, 32, 9, 9})),
              0)
              .WithShape(S8, {1, 32, 9, 9})));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.conv_result_scale(), 1);
  EXPECT_EQ(config.side_input_scale(), 0.25);
}

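// alpha * convert<f32>(conv(s8, s8)) should fuse alpha as conv_result_scale.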
TEST_F(CudnnFusedConvRewriterHloTest, FuseAlpha) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      input = s8[1,17,9,9] parameter(0)
      filter = s8[3,3,17,32] parameter(1)
      inputs32 = s32[1,17,9,9] convert(input)
      filters32 = s32[3,3,17,32] convert(filter)
      alpha = f32[] constant(42)
      alpha_broadcast = f32[1,32,9,9] broadcast(alpha), dimensions={}

      conv = s32[1,32,9,9] convolution(inputs32, filters32),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      convert = f32[1,32,9,9] convert(conv)
      ROOT root = multiply(convert, alpha_broadcast)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv = nullptr;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::GetTupleElement(
              m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget),
              0)
              .WithShape(F32, {1, 32, 9, 9})));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.conv_result_scale(), 42);
}

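// max(0, conv(x, w) + bias) should fuse with activation_mode=kRelu.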
TEST_F(CudnnFusedConvRewriterHloTest, FuseRelu) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] parameter(0)
      filters = f32[3,3,17,32] parameter(1)
      bias = f32[32] parameter(2)
      bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
      zero = f32[] constant(0)
      zeros = f32[1,32,9,9] broadcast(zero), dimensions={}
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      sum = add(conv, bias_broadcast)
      ROOT relu = maximum(sum, zeros)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::GetTupleElement(
              m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
                            m::Parameter(0), m::Parameter(1), m::Parameter(2)),
              0)
              .WithShape(F32, {1, 32, 9, 9})));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
}

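// The relu can't be fused because the sum it consumes has a second user.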
TEST_F(CudnnFusedConvRewriterHloTest, DontFuseReluIfMultipleUses) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] parameter(0)
      filters = f32[3,3,17,32] parameter(1)
      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
      zeros = f32[1,32,9,9] broadcast(f32[] constant(0)), dimensions={}
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      sum = add(conv, bias)
      relu = maximum(sum, zeros)
      not_relu = minimum(sum, zeros)
      ROOT root = tuple(relu, not_relu)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(m::Tuple(
          m::MaximumAnyOrder(
              m::Broadcast(m::ConstantEffectiveScalar(0)),
              m::GetTupleElement(
                  m::CustomCall(
                      &conv, kCudnnConvBiasActivationForwardCallTarget,
                      m::Parameter(0), m::Parameter(1), m::Parameter(2)),
                  0)
                  .WithShape(F32, {1, 32, 9, 9})),
          m::Minimum())));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
}

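// alpha can't be fused as conv_result_scale because the conv has a second
// user.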
TEST_F(CudnnFusedConvRewriterHloTest, DontFuseAlphaIfMultipleUsers) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] parameter(0)
      filters = f32[3,3,17,32] parameter(1)
      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
      alpha = f32[1,32,9,9] broadcast(f32[] parameter(3)), dimensions={}
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      sum = add(multiply(alpha, conv), bias)
      ROOT root = tuple(conv, sum)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv1;
  const HloInstruction* conv2;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(m::Tuple(
          m::GetTupleElement(m::CustomCall(&conv1), 0),
          m::AddAnyOrder(m::Broadcast(m::Parameter(2)),
                         m::MultiplyAnyOrder(
                             m::Broadcast(m::Parameter(3)),
                             m::GetTupleElement(m::CustomCall(&conv2), 0))))));
  EXPECT_EQ(conv1, conv2);
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv1->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.conv_result_scale(), 1);
  EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
}

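// The bias add can't be fused because the conv has a second user.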
TEST_F(CudnnFusedConvRewriterHloTest, DontFuseBiasIfMultipleUsers) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] parameter(0)
      filters = f32[3,3,17,32] parameter(1)
      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      ROOT root = tuple(conv, add(conv, bias))
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv1;
  const HloInstruction* conv2;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(m::Tuple(
          m::GetTupleElement(m::CustomCall(&conv1), 0),
          m::AddAnyOrder(m::Broadcast(m::Parameter(2)),
                         m::GetTupleElement(m::CustomCall(&conv2), 0)))));
  EXPECT_EQ(conv1, conv2);
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv1->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.conv_result_scale(), 1);
  EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
}

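// A side input added after the relu must stay outside the fused conv.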
TEST_F(CudnnFusedConvRewriterHloTest, DontFuseSideInputThroughRelu) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] parameter(0)
      filters = f32[3,3,17,32] parameter(1)
      side_input = f32[1,32,9,9] parameter(2)
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      relu = maximum(conv, f32[1,32,9,9] broadcast(f32[] constant(0)))
      ROOT root = add(relu, side_input)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(m::AddAnyOrder(
          m::Parameter(2),
          m::GetTupleElement(
              m::CustomCall(&conv, m::Parameter(0), m::Parameter(1),
                            m::Broadcast(m::ConstantEffectiveScalar(0))),
              0))));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.conv_result_scale(), 1);
  EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
}

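// A bias added after the relu must stay outside the fused conv.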
TEST_F(CudnnFusedConvRewriterHloTest, DontFuseBiasThroughRelu) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] parameter(0)
      filters = f32[3,3,17,32] parameter(1)
      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      relu = maximum(conv, f32[1,32,9,9] broadcast(f32[] constant(0)))
      ROOT root = add(relu, bias)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
  ASSERT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::AddAnyOrder(
                  m::Broadcast(m::Parameter(2)),
                  m::GetTupleElement(m::CustomCall(
                      &conv, m::Parameter(0), m::Parameter(1),
                      m::Broadcast(m::ConstantEffectiveScalar(0)))))));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.conv_result_scale(), 1);
  EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
}

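// The side-input add can't be fused because the conv has a second user.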
TEST_F(CudnnFusedConvRewriterHloTest, DontFuseSideInputIfMultipleUsers) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] parameter(0)
      filters = f32[3,3,17,32] parameter(1)
      side_input = f32[1,32,9,9] parameter(2)
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      ROOT root = tuple(conv, add(conv, side_input))
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv1;
  const HloInstruction* conv2;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(m::Tuple(
          m::GetTupleElement(m::CustomCall(&conv1), 0),
          m::AddAnyOrder(m::Parameter(2),
                         m::GetTupleElement(m::CustomCall(&conv2), 0)))));
  EXPECT_EQ(conv1, conv2);
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv1->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.conv_result_scale(), 1);
  EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
}

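// The f32->f16 convert can't be fused because the f32 conv result has a
// second user.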
TEST_F(CudnnFusedConvRewriterHloTest, DontFuseConvertToF16IfMultipleUsers) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] convert(f16[1,17,9,9] parameter(0))
      filters = f32[3,3,17,32] convert(f16[3,3,17,32] parameter(1))
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      ROOT root = tuple(conv, f16[1,32,9,9] convert(conv))
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv1;
  const HloInstruction* conv2;
  ASSERT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Tuple(
                  m::GetTupleElement(m::CustomCall(&conv1), 0),
                  m::Convert(m::GetTupleElement(m::CustomCall(&conv2), 0)))));
  EXPECT_EQ(conv1, conv2);
}

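// The clamp-and-convert to s8 can't be fused because the f32 conv result has
// a second user.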
TEST_F(CudnnFusedConvRewriterHloTest, DontFuseToS8IfMultipleUsers) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
      filters = f32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      conv_s8 = s8[1,32,9,9] convert(clamp(
                  f32[1,32,9,9] broadcast(f32[] constant(-128)),
                  conv,
                  f32[1,32,9,9] broadcast(f32[] constant(127))))
      ROOT root = tuple(conv, conv_s8)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv1;
  const HloInstruction* conv2;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(m::Tuple(
          m::GetTupleElement(m::CustomCall(&conv1), 0),
          m::Convert(m::Clamp(m::Op(),  //
                              m::GetTupleElement(m::CustomCall(&conv2), 0),
                              m::Op())))));
  EXPECT_EQ(conv1, conv2);
}

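// conv(x, w) + bias should fuse into a bias-activation custom-call.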
TEST_F(CudnnFusedConvRewriterHloTest, FuseBias) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] parameter(0)
      filters = f32[3,3,17,32] parameter(1)
      bias = f32[32] parameter(2)
      bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      ROOT root = add(conv, bias_broadcast)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::GetTupleElement(
              m::CustomCall(kCudnnConvBiasActivationForwardCallTarget,
                            m::Parameter(0), m::Parameter(1), m::Parameter(2)),
              0)
              .WithShape(F32, {1, 32, 9, 9})));
}

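// conv(x, w) + side_input should fuse, with a zero bias synthesized.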
TEST_F(CudnnFusedConvRewriterHloTest, FuseSideInput) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] parameter(0)
      filters = f32[3,3,17,32] parameter(1)
      side_input = f32[1,32,9,9] parameter(2)
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      ROOT root = add(conv, side_input)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::GetTupleElement(
              m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
                            m::Parameter(0), m::Parameter(1),
                            m::Broadcast(m::ConstantEffectiveScalar(0))
                                .WithShape(F32, {32}),
                            m::Parameter(2)),
              0)
              .WithShape(F32, {1, 32, 9, 9})));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.side_input_scale(), 1);
}

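// conv(x, w) + scale * side_input should fuse scale as side_input_scale.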
TEST_F(CudnnFusedConvRewriterHloTest, FuseScaledSideInput) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] parameter(0)
      filters = f32[3,3,17,32] parameter(1)
      side_input = f32[1,32,9,9] parameter(2)
      side_input_scale = f32[] constant(42)
      side_input_scale_broadcast = f32[1,32,9,9] broadcast(side_input_scale), dimensions={}
      side_input_product = multiply(side_input, side_input_scale_broadcast)
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      ROOT root = add(conv, side_input_product)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::GetTupleElement(
              m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
                            m::Parameter(0), m::Parameter(1),
                            m::Broadcast(m::ConstantEffectiveScalar(0))
                                .WithShape(F32, {32}),
                            m::Parameter(2)),
              0)
              .WithShape(F32, {1, 32, 9, 9})));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.side_input_scale(), 42);
}

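// conv(x, w) + side_input + bias should fuse both addends.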
TEST_F(CudnnFusedConvRewriterHloTest, FuseBiasAndSideInput) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] parameter(0)
      filters = f32[3,3,17,32] parameter(1)
      bias = f32[32] parameter(2)
      side_input = f32[1,32,9,9] parameter(3)
      bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      sum = add(conv, side_input)
      ROOT sum2 = add(sum, bias_broadcast)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::GetTupleElement(
              m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
                            m::Parameter(0), m::Parameter(1), m::Parameter(2),
                            m::Parameter(3)),
              0)
              .WithShape(F32, {1, 32, 9, 9})));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.side_input_scale(), 1);
}

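// A scalar broadcast as bias should be rewritten to a [32] broadcast feeding
// the fused conv's bias operand.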
TEST_F(CudnnFusedConvRewriterHloTest, EffectiveScalarBias) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] parameter(0)
      filters = f32[3,3,17,32] parameter(1)
      bias = f32[1,32,9,9] broadcast(f32[] parameter(2)), dimensions={}
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      ROOT root = add(conv, bias)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::GetTupleElement(
              m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
                            m::Parameter(0), m::Parameter(1),
                            m::Broadcast(m::Parameter(2)).WithShape(F32, {32})),
              0)
              .WithShape(F32, {1, 32, 9, 9})));
}

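// An f32 conv whose operands are all converts from f16 (and whose result is
// converted back to f16) should be strength-reduced to an f16 fused conv.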
TEST_F(CudnnFusedConvRewriterHloTest, StrengthReduceF32ToF16) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f16[1,17,9,9] parameter(0)
      filters = f16[3,3,17,32] parameter(1)
      bias = f16[32] parameter(2)
      side_input = f16[1,32,9,9] parameter(3)

      inputs_f32 = f32[1,17,9,9] convert(inputs)
      filters_f32 = f32[3,3,17,32] convert(filters)
      bias_f32 = f32[32] convert(bias)
      bias_broadcast = f32[1,32,9,9] broadcast(bias_f32), dimensions={1}
      side_input_f32 = f32[1,32,9,9] convert(side_input)
      conv = f32[1,32,9,9] convolution(inputs_f32, filters_f32),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      sum = add(conv, side_input_f32)
      sum2 = add(sum, bias_broadcast)
      ROOT conv_f16 = f16[1,32,9,9] convert(sum2)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  // Simplify new `convert`s that may be added to the graph.
  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::GetTupleElement(
              m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
                            m::Parameter(0), m::Parameter(1), m::Parameter(2),
                            m::Parameter(3)),
              0)
              .WithShape(F16, {1, 32, 9, 9})));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.side_input_scale(), 1);
}

// We should be able to lower this to an f16 convolution even though the
// f16-ness of the inputs is hidden behind broadcast/transpose/reshape.
TEST_F(CudnnFusedConvRewriterHloTest, BroadcastReshapeTransposeAfterConvert) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] reshape(f32[1377] convert(f16[1377] parameter(0)))
      filters = f32[3,3,17,32] transpose(f32[17,32,3,3] convert(f16[17,32,3,3] parameter(1))), dimensions={2,3,0,1}
      bias = f16[1,32,9,9] broadcast(f16[32] parameter(2)), dimensions={1}
      side_input = f16[1,32,9,9] reshape(f16[2592] parameter(3))

      conv_f32 = f32[1,32,9,9] convolution(inputs, filters),
                   window={size=3x3 pad=1_1x1_1},
                   dim_labels=bf01_01io->bf01
      conv_f16 = f16[1,32,9,9] convert(conv_f32)
      ROOT root = f16[1,32,9,9] add(add(conv_f16, side_input), bias)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  // Simplify new `convert`s that may be added to the graph.
  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(m::GetTupleElement(
                     m::CustomCall(
                         &conv, kCudnnConvBiasActivationForwardCallTarget,
                         m::Convert(m::Reshape(m::Convert(m::Parameter(0))))
                             .WithElementType(F16),
                         m::Convert(m::Transpose(m::Convert(m::Parameter(1))))
                             .WithElementType(F16),
                         m::Parameter(2), m::Reshape(m::Parameter(3))),
                     0)
                     .WithShape(F16, {1, 32, 9, 9})));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.side_input_scale(), 1);
}

TEST_F(CudnnFusedConvRewriterHloTest, NoStrengthReduceF32ToF16IfBiasIsF32) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f16[1,17,9,9] parameter(0)
      filters = f16[3,3,17,32] parameter(1)
      bias = f32[32] parameter(2)
      side_input = f16[1,32,9,9] parameter(3)

      inputs_f32 = f32[1,17,9,9] convert(inputs)
      filters_f32 = f32[3,3,17,32] convert(filters)
      bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
      side_input_f32 = f32[1,32,9,9] convert(side_input)
      conv = f32[1,32,9,9] convolution(inputs_f32, filters_f32),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      sum = add(conv, side_input_f32)
      sum2 = add(sum, bias_broadcast)
      ROOT conv_f16 = f16[1,32,9,9] convert(sum2)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  // Simplify new `convert`s that may be added to the graph.
  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
  // fp16 convs only support fp16 biases. Because bias is fp32, it doesn't get
  // fused in, and we get an fp32 conv.
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::Convert(m::GetTupleElement(
                         m::CustomCall(
                             &conv, kCudnnConvBiasActivationForwardCallTarget,
                             m::Convert(m::Parameter(0)).WithElementType(F32),
                             m::Convert(m::Parameter(1)).WithElementType(F32),
                             m::Parameter(2),
                             m::Convert(m::Parameter(3)).WithElementType(F32)),
                         0))
              .WithShape(F16, {1, 32, 9, 9})));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.side_input_scale(), 1);
}

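// f32 constants that are losslessly convertible to f16 should be folded so
// the conv runs entirely in f16.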
TEST_F(CudnnFusedConvRewriterHloTest, F32Constants) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f16[1,2,2,2] parameter(0)
      filters_f32 = f32[1,1,2,2] constant({{{{1, 2},{3, 4}}}})
      bias = f16[2] parameter(1)
      bias_f32 = f32[2] convert(bias)
      side_input_f32 = f32[1,2,2,2] constant({{
        {{0.5, 0.25}, {0.125, 0.0625}},
        {{0.5, 0.25}, {0.125, 0.0625}}
      }})

      inputs_f32 = f32[1,2,2,2] convert(inputs)
      bias_broadcast = f32[1,2,2,2] broadcast(bias_f32), dimensions={1}
      conv = f32[1,2,2,2] convolution(inputs_f32, filters_f32),
               window={size=1x1}, dim_labels=bf01_01io->bf01
      sum = add(conv, side_input_f32)
      sum2 = add(sum, bias_broadcast)
      ROOT conv_f16 = f16[1,2,2,2] convert(sum2)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  // Simplify new `convert`s that may be added to the graph, and fold
  // converts back into constants.
  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
  HloConstantFolding constant_folding;
  TF_ASSERT_OK(RunHloPass(&constant_folding, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(m::GetTupleElement(
                     m::CustomCall(
                         &conv, kCudnnConvBiasActivationForwardCallTarget,
                         m::Parameter(0), m::Constant().WithElementType(F16),
                         m::Parameter(1), m::Constant().WithElementType(F16)),
                     0)
                     .WithShape(F16, {1, 2, 2, 2})));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.side_input_scale(), 1);
}

TEST_F(CudnnFusedConvRewriterHloTest,F32ConstantsNotLosslesslyConvertible)1401 TEST_F(CudnnFusedConvRewriterHloTest, F32ConstantsNotLosslesslyConvertible) {
1402 const std::string module_str = R"(
1403 HloModule Test
1404
1405 ENTRY Test {
1406 inputs = f16[1,2,2,2] parameter(0)
1407 filters_f32 = f32[1,1,2,2] constant({{{{1, 2.123456789},{3, 4}}}})
1408 bias = f16[2] parameter(1)
1409 bias_f32 = f32[2] convert(bias)
1410 side_input_f32 = f32[1,2,2,2] constant({{
1411 {{0.1, 0.2}, {0.3, 0.4}},
1412 {{0.5, 0.6}, {0.7, 0.8}}
1413 }})
1414
1415 inputs_f32 = f32[1,2,2,2] convert(inputs)
1416 bias_broadcast = f32[1,2,2,2] broadcast(bias_f32), dimensions={1}
1417 conv = f32[1,2,2,2] convolution(inputs_f32, filters_f32),
1418 window={size=1x1}, dim_labels=bf01_01io->bf01
1419 sum = add(conv, side_input_f32)
1420 sum2 = add(sum, bias_broadcast)
1421 ROOT conv_f16 = f16[1,2,2,2] convert(sum2)
1422 })";
1423 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
1424
1425 GpuConvRewriter rewriter;
1426 TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
1427 CudnnFusedConvRewriter fuser;
1428 TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
1429
1430 // Simplify new `convert`'s that may be added to the graph, and fold
1431 // convert back into constants.
1432 AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
1433 TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
1434 HloConstantFolding constant_folding;
1435 TF_ASSERT_OK(RunHloPass(&constant_folding, m.get()).status());
1436
1437 SCOPED_TRACE(m->ToString());
1438 const HloInstruction* conv;
1439 // This doesn't get transformed into an f16 conv because the filters param is
1440 // not losslessly expressible as f16.
1441 ASSERT_THAT(
1442 m->entry_computation()->root_instruction(),
1443 GmockMatch(
1444 m::Convert(m::GetTupleElement(
1445 m::CustomCall(
1446 &conv, kCudnnConvBiasActivationForwardCallTarget,
1447 m::Convert(m::Parameter(0)).WithElementType(F32),
1448 m::Constant().WithElementType(F32),
1449 m::Convert(m::Parameter(1)).WithElementType(F32),
1450 m::Constant().WithElementType(F32)),
1451 0)
1452 .WithShape(F32, {1, 2, 2, 2}))
1453 .WithElementType(F16)));
1454 TF_ASSERT_OK_AND_ASSIGN(auto config,
1455 conv->backend_config<CudnnConvBackendConfig>());
1456 EXPECT_EQ(config.side_input_scale(), 1);
1457 }

TEST_F(CudnnFusedConvRewriterHloTest, FuseReluBeforeConvert) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      input = s8[1,17,9,9] parameter(0)
      filter = s8[3,3,17,32] parameter(1)
      inputs32 = s32[1,17,9,9] convert(input)
      filters32 = s32[3,3,17,32] convert(filter)

      conv = s32[1,32,9,9] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1

      zero = s32[] constant(0)
      zeros = s32[1,32,9,9] broadcast(zero), dimensions={}
      relu = maximum(conv, zeros)

      lower = s32[] constant(-128)
      lowers = s32[1,32,9,9] broadcast(lower), dimensions={}
      upper = s32[] constant(127)
      uppers = s32[1,32,9,9] broadcast(upper), dimensions={}

      clamp = s32[1,32,9,9] clamp(lowers, relu, uppers)

      ROOT convert = s8[1,32,9,9] convert(clamp)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  // Simplify the new `convert`s that may have been added to the graph.
  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());

  SCOPED_TRACE(m->ToString());
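  // The conv-bias-activation custom-call takes a bias operand, but the
  // original graph has no bias; the rewriter therefore has to synthesize an
  // all-zeros f32 bias, which is what the broadcast-of-0 matcher below
  // checks for.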
  const HloInstruction* conv;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::GetTupleElement(
              m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
                            m::Parameter(0),  //
                            m::Parameter(1),  //
                            m::Broadcast(m::ConstantEffectiveScalar(0))
                                .WithShape(F32, {32})),
              0)
              .WithShape(S8, {1, 32, 9, 9})));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
}

TEST_F(CudnnFusedConvRewriterHloTest, BiasTypeMatchesConvTypeIfFp) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      input = f64[1,17,9,9] parameter(0)
      filter = f64[3,3,17,32] parameter(1)
      bias = f64[1,32,9,9] broadcast(f64[32] convert(f32[32] parameter(2))), dimensions={1}
      conv = f64[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
      ROOT root = f64[1,32,9,9] add(conv, bias)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  // Simplify the new `convert`s that may have been added to the graph.
  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());

  SCOPED_TRACE(m->ToString());
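  // Since the conv runs in f64, the f32 bias parameter should be converted
  // to f64 rather than left as f32; the matcher below verifies the convert.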
  const HloInstruction* conv;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::GetTupleElement(
              m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
                            m::Parameter(0),  //
                            m::Parameter(1),  //
                            m::Convert(m::Parameter(2)).WithShape(F64, {32})),
              0)
              .WithShape(F64, {1, 32, 9, 9})));
}

TEST_F(CudnnFusedConvRewriterTest, TestFusedConvInt8ToInt8) {
  // clamp(max(0, conv(x, w) + bias)); for int8_t
  TestClamp(
      // pre_hlo
      R"(
    HloModule Test

    ENTRY Test {
      zero = f32[] constant(0)
      zeros = f32[1,3,3,64] broadcast(zero), dimensions={}

      input = s8[1,3,3,64] parameter(0)
      filter = s8[3,3,64,64] parameter(1)
      bias = f32[64] parameter(2)

      inputs32 = s32[1,3,3,64] convert(input)
      filters32 = s32[3,3,64,64] convert(filter)

      conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1

      convfloat = f32[1,3,3,64] convert(conv)
      broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
      add1 = f32[1,3,3,64] add(convfloat, broadcasted_bias)
      relu = f32[1,3,3,64] maximum(zeros, add1)

      lower = f32[] constant(-128)
      lowers = f32[1,3,3,64] broadcast(lower), dimensions={}
      upper = f32[] constant(127)
      uppers = f32[1,3,3,64] broadcast(upper), dimensions={}

      clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)

      ROOT convert = s8[1,3,3,64] convert(clamp)
    })",
      // post_hlo
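      // Note: the {{(\.[0-9])?}} patterns below are FileCheck regexes; they
      // tolerate the optional ".N" suffix an instruction name may acquire
      // when the optimizer clones or renames it.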
      R"(
    ; CHECK-LABEL: ENTRY %Test (input: s8[1,3,3,64], filter: s8[3,3,64,64], bias: f32[64]) -> s8[1,3,3,64]
    ; CHECK: %cudnn-conv-bias-activation{{(\.[0-9])?}} = (s8[1,3,3,64]{3,2,1,0}, u8[{{[0-9]+}}]{0}) custom-call(%input, %copy{{(\.[0-9])?}}, %bias), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", backend_config=
    ; CHECK-NEXT: ROOT %get-tuple-element{{(\.[0-9])?}} = s8[1,3,3,64]{3,2,1,0} get-tuple-element(%cudnn-conv-bias-activation{{(\.[0-9])?}}), index=0
      )");
}

// Disabled per b/190854862 or nvbugs/3326122.
TEST_F(CudnnFusedConvRewriterTest, DISABLED_TestFusedConvInt8ToFloat) {
  // max(0, convert<float>(conv<int32_t>(int8_x, int8_w)) + float_bias);
  // int8_t to float via bias.
  TestClamp(
      // pre_hlo
      R"(
    HloModule Test

    ENTRY Test {
      zero = f32[] constant(0)
      zeros = f32[1,3,3,64] broadcast(zero), dimensions={}

      input = s8[1,3,3,64] parameter(0)
      filter = s8[3,3,64,64] parameter(1)
      bias = f32[64] parameter(2)

      inputs32 = s32[1,3,3,64] convert(input)
      filters32 = s32[3,3,64,64] convert(filter)

      conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1

      convfloat = f32[1,3,3,64] convert(conv)
      broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
      add1 = f32[1,3,3,64] add(convfloat, broadcasted_bias)
      ROOT relu = f32[1,3,3,64] maximum(zeros, add1)
    })",
      // post_hlo
      R"(
    ; CHECK-LABEL: ENTRY %Test (input: s8[1,3,3,64], filter: s8[3,3,64,64], bias: f32[64]) -> f32[1,3,3,64] {
    ; CHECK: %custom-call{{(\.[0-9])?}} = (f32[1,3,3,64]{3,2,1,0}, u8[{{[0-9]*}}]{0}) custom-call(%input, %copy{{(\.[0-9])?}}, %bias), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", backend_config=
    ; CHECK-NEXT: ROOT %get-tuple-element{{(\.[0-9])?}} = f32[1,3,3,64]{3,2,1,0} get-tuple-element(%custom-call{{(\.[0-9])?}}), index=0
      )");
}

TEST_F(CudnnFusedConvRewriterTest,
       TestFusedConvWithScaledInt8SideInputBiasInt8ToInt8) {
  // clamp(max(0, alpha_conv * conv(x, w) + alpha_side *
  //            convert<float>(int8_side_input) + bias)); for int8_t
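  // Expected result: a single s8 conv-bias-activation custom-call that
  // consumes the s8 side input directly (see the post_hlo CHECK lines).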
  TestClamp(
      // pre_hlo
      R"(
    HloModule Test

    ENTRY Test {
      zero = f32[] constant(0)
      zeros = f32[1,3,3,64] broadcast(zero), dimensions={}
      alpha_conv_scalar = f32[] constant(0.999994934)
      alpha_conv = f32[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
      alpha_side_input_scalar = f32[] constant(0.899994934)
      alpha_side_input = f32[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}

      input = s8[1,3,3,64] parameter(0)
      filter = s8[3,3,64,64] parameter(1)
      side_input = s8[1,3,3,64] parameter(2)
      bias = f32[64] parameter(3)

      inputs32 = s32[1,3,3,64] convert(input)
      filters32 = s32[3,3,64,64] convert(filter)
      side_input_f32 = f32[1,3,3,64] convert(side_input)

      conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1

      convfloat = f32[1,3,3,64] convert(conv)
      scaled_conv = f32[1,3,3,64] multiply(convfloat, alpha_conv)
      scaled_side_input = f32[1,3,3,64] multiply(side_input_f32, alpha_side_input)
      broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
      add1 = f32[1,3,3,64] add(scaled_conv, broadcasted_bias)
      add2 = f32[1,3,3,64] add(add1, scaled_side_input)
      relu = f32[1,3,3,64] maximum(zeros, add2)

      lower = f32[] constant(-128)
      lowers = f32[1,3,3,64] broadcast(lower), dimensions={}
      upper = f32[] constant(127)
      uppers = f32[1,3,3,64] broadcast(upper), dimensions={}

      clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)

      ROOT convert = s8[1,3,3,64] convert(clamp)
    })",
      // post_hlo
      R"(
    ; CHECK-LABEL: ENTRY %Test (input: s8[1,3,3,64], filter: s8[3,3,64,64], side_input: s8[1,3,3,64], bias: f32[64]) -> s8[1,3,3,64] {
    ; CHECK: %cudnn-conv-bias-activation{{(\.[0-9]+)?}} =
    ; CHECK-SAME: (s8[1,3,3,64]{3,2,1,0}, u8[{{[0-9]+}}]{0})
    ; CHECK-SAME: custom-call(%input, %copy{{(\.[0-9]+)?}}, %bias, %side_input),
    ; CHECK-SAME: window={size=3x3 pad=1_1x1_1},
    ; CHECK-SAME: dim_labels=b01f_01io->b01f,
    ; CHECK-SAME: custom_call_target="__cudnn$convBiasActivationForward",
    ; CHECK-NEXT: ROOT %get-tuple-element{{(\.[0-9]+)?}} = s8[1,3,3,64]{3,2,1,0} get-tuple-element(%cudnn-conv-bias-activation{{(\.[0-9]+)?}}), index=0
      )");
}

TEST_F(CudnnFusedConvRewriterTest,
       TestFusedConvWithScaledFloatSideInputBiasInt8ToInt8) {
  // From:
  //   convert<int8_t>(clamp(max(0, alpha_conv * conv(x, w) +
  //                                alpha_side * float_side_input + bias)));
  // To:
  //   convert<int8_t>(clamp(conv(int8_x, int8_w, float_alpha_side,
  //                              float_side_input, float_bias)));
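  // Here the custom-call's output stays f32: the trailing clamp and s8
  // convert are not folded into it and remain behind as a separate kLoop
  // fusion (the ROOT %fusion CHECK line).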
  TestClamp(
      // pre_hlo
      R"(
    HloModule Test

    ENTRY Test {
      zero = f32[] constant(0)
      zeros = f32[1,3,3,64] broadcast(zero), dimensions={}
      alpha_conv_scalar = f32[] constant(0.999994934)
      alpha_conv = f32[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
      alpha_side_input_scalar = f32[] constant(0.899994934)
      alpha_side_input = f32[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}

      input = s8[1,3,3,64] parameter(0)
      filter = s8[3,3,64,64] parameter(1)
      side_input = f32[1,3,3,64] parameter(2)
      bias = f32[64] parameter(3)

      inputs32 = s32[1,3,3,64] convert(input)
      filters32 = s32[3,3,64,64] convert(filter)

      conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1

      convfloat = f32[1,3,3,64] convert(conv)
      scaled_conv = f32[1,3,3,64] multiply(convfloat, alpha_conv)
      scaled_side_input = f32[1,3,3,64] multiply(side_input, alpha_side_input)
      broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
      add1 = f32[1,3,3,64] add(scaled_conv, broadcasted_bias)
      add2 = f32[1,3,3,64] add(add1, scaled_side_input)
      relu = f32[1,3,3,64] maximum(zeros, add2)

      lower = f32[] constant(-128)
      lowers = f32[1,3,3,64] broadcast(lower), dimensions={}
      upper = f32[] constant(127)
      uppers = f32[1,3,3,64] broadcast(upper), dimensions={}

      clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)

      ROOT convert = s8[1,3,3,64] convert(clamp)
    })",
      // post_hlo
      R"(
    ; CHECK-LABEL: ENTRY %Test (input: s8[1,3,3,64], filter: s8[3,3,64,64], side_input: f32[1,3,3,64], bias: f32[64]) -> s8[1,3,3,64] {
    ; CHECK: %cudnn-conv-bias-activation{{(\.[0-9])?}} = (f32[1,3,3,64]{3,2,1,0}, u8[{{[0-9]+}}]{0}) custom-call(%input, %copy{{(\.[0-9])?}}, %bias, %side_input), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", backend_config=
    ; CHECK: ROOT %fusion = s8[1,3,3,64]{3,2,1,0} fusion(%get-tuple-element{{(\.[0-9])?}}), kind=kLoop, calls=%fused_computation
      )");
}

TEST_F(CudnnFusedConvRewriterTest,
       TestFusedConvWithScaledInt8SideInputBiasInt8ToFloat) {
  // From:
  //   clamp(max(0, alpha_conv * conv(x, w) +
  //                alpha_side * convert<float>(int8_side_input) + bias));
  // To:
  //   clamp(conv(int8_x, int8_w, float_alpha_side,
  //              convert<float>(int8_side_input), float_bias));
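  // Since the fused conv runs in f32 here, the s8 side input must still be
  // converted to f32 before the custom-call (the %side_input_f32 CHECK line).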
  TestClamp(
      // pre_hlo
      R"(
    HloModule Test

    ENTRY Test {
      zero = f32[] constant(0)
      zeros = f32[1,3,3,64] broadcast(zero), dimensions={}
      alpha_conv_scalar = f32[] constant(0.999994934)
      alpha_conv = f32[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
      alpha_side_input_scalar = f32[] constant(0.899994934)
      alpha_side_input = f32[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}

      input = s8[1,3,3,64] parameter(0)
      filter = s8[3,3,64,64] parameter(1)
      side_input = s8[1,3,3,64] parameter(2)
      bias = f32[64] parameter(3)

      inputs32 = s32[1,3,3,64] convert(input)
      filters32 = s32[3,3,64,64] convert(filter)
      side_input_f32 = f32[1,3,3,64] convert(side_input)

      conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1

      convfloat = f32[1,3,3,64] convert(conv)
      scaled_conv = f32[1,3,3,64] multiply(convfloat, alpha_conv)
      scaled_side_input = f32[1,3,3,64] multiply(side_input_f32, alpha_side_input)
      broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
      add1 = f32[1,3,3,64] add(scaled_conv, broadcasted_bias)
      add2 = f32[1,3,3,64] add(add1, scaled_side_input)
      relu = f32[1,3,3,64] maximum(zeros, add2)

      lower = f32[] constant(-128)
      lowers = f32[1,3,3,64] broadcast(lower), dimensions={}
      upper = f32[] constant(127)
      uppers = f32[1,3,3,64] broadcast(upper), dimensions={}

      ROOT clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)
    })",
      // post_hlo
      R"(
    ; CHECK-LABEL: ENTRY %Test (input: s8[1,3,3,64], filter: s8[3,3,64,64], side_input: s8[1,3,3,64], bias: f32[64]) -> f32[1,3,3,64] {
    ; CHECK: %side_input_f32 = f32[1,3,3,64]{3,2,1,0} convert(%side_input)
    ; CHECK: %cudnn-conv-bias-activation{{(\.[0-9])?}} = (f32[1,3,3,64]{3,2,1,0}, u8[{{[0-9]*}}]{0}) custom-call(%input, %copy{{(\.[0-9])?}}, %bias, %side_input_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", backend_config=
    ; CHECK: ROOT %fusion = f32[1,3,3,64]{3,2,1,0} fusion(%get-tuple-element{{(\.[0-9])?}}), kind=kLoop, calls=%fused_computation
      )");
}

TEST_F(CudnnFusedConvRewriterTest, TestConvInt8ToInt8NoClamp) {
  // Check that an integer convolution whose result is converted to int8_t
  // without an intervening clamp is not allowed:
  //   convert<int8_t>(custom_call<int32_t>(int8_x, int8_w,
  //                                        cudnnConvolutionForward))
  const std::string module_str = absl::StrFormat(R"(
    HloModule Test

    ENTRY Test (input: s8[1,17,9,9], filter: s8[3,3,17,32]) -> s8[1,32,9,9] {
      zero = s8[] constant(0)
      zeros = s8[1,32,9,9]{3,2,1,0} broadcast(s8[] zero), dimensions={}
      input = s8[1,17,9,9]{3,2,1,0} parameter(0)
      filter = s8[3,3,17,32]{3,2,1,0} parameter(1)
      custom-call = (s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call(s8[1,17,9,9]{3,2,1,0} input, s8[3,3,17,32]{3,2,1,0} filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convForward", backend_config="{\"convResultScale\":1}"
      get-tuple-element = s32[1,32,9,9]{3,2,1,0} get-tuple-element((s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call), index=0
      convert = s8[1,32,9,9]{3,2,1,0} convert(s32[1,32,9,9]{3,2,1,0} get-tuple-element)
      ROOT relu = s8[1,32,9,9]{3,2,1,0} maximum(s8[1,32,9,9]{3,2,1,0} zeros, s8[1,32,9,9]{3,2,1,0} convert)
    })");
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

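  // The rewrite pass itself is expected to fail (return a non-OK status) on
  // this module, not merely leave it unchanged, so that the unclamped
  // s32->s8 convert can't slip through to codegen.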
  ASSERT_FALSE(CudnnFusedConvRewriter().Run(m.get()).ok());
}

TEST_F(CudnnFusedConvRewriterTest, TestFusedConvInt8ToInt8NoClamp) {
  // Although the bias and other terms are fused into the forward convolution,
  // the fusion is still not allowed if the output is not clamped/converted to
  // int8_t:
  //   max(0, alpha_conv * conv(x, w) + alpha_side * side_input + bias);
  //   for int8_t
  const std::string module_str = absl::StrFormat(R"(
    HloModule Test

    ENTRY Test (input: s8[1,17,9,9], filter: s8[3,3,17,32]) -> s8[1,32,9,9] {
      zero = s8[] constant(0)
      zeros = s8[1,32,9,9]{3,2,1,0} broadcast(s8[] zero), dimensions={}
      input = s8[1,17,9,9]{3,2,1,0} parameter(0)
      filter = s8[3,3,17,32]{3,2,1,0} parameter(1)
      custom-call = (s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call(s8[1,17,9,9]{3,2,1,0} input, s8[3,3,17,32]{3,2,1,0} filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convForward", backend_config="{\"convResultScale\":1}"
      get-tuple-element = s32[1,32,9,9]{3,2,1,0} get-tuple-element((s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call), index=0
      convert = s8[1,32,9,9]{3,2,1,0} convert(s32[1,32,9,9]{3,2,1,0} get-tuple-element)
      ROOT relu = s8[1,32,9,9]{3,2,1,0} maximum(s8[1,32,9,9]{3,2,1,0} zeros, s8[1,32,9,9]{3,2,1,0} convert)
    })");
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  ASSERT_FALSE(CudnnFusedConvRewriter().Run(m.get()).ok());
}

}  // namespace
}  // namespace gpu
}  // namespace xla