/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h"

#include <string>

#include "absl/strings/str_replace.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/convert_mover.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace xla {
namespace gpu {
namespace {

// TODO(b/210165681): The tests in this file are fragile to HLO op names.

namespace m = match;

using ::testing::HasSubstr;
using ::testing::Not;

class CudnnFusedConvRewriterHloTest : public HloTestBase {
 public:
  CudnnFusedConvRewriterHloTest()
      : HloTestBase(/*verifier_layout_sensitive=*/false,
                    /*allow_mixed_precision_in_hlo_verifier=*/false,
                    /*instruction_can_change_layout_func=*/{}) {}
};

class CudnnFusedConvRewriterTest : public GpuCodegenTest {
 protected:
  std::string GetOptimizedHlo(absl::string_view hlo_string) {
    // cudnn_vectorize_convolutions transforms convolutions, making it hard to
    // match them here in this test.  What's worse, the transforms it does
    // depend on the GPU that's available!  So just disable it for this
    // function that gets the optimized HLO.  When we actually run the module
    // we'll still have this pass enabled.
    HloModuleConfig config = GetModuleConfigForTest();
    DebugOptions debug_opts = config.debug_options();
    debug_opts.add_xla_disable_hlo_passes("cudnn_vectorize_convolutions");
    config.set_debug_options(debug_opts);

    auto result = backend().compiler()->RunHloPasses(
        ParseAndReturnVerifiedModule(hlo_string, config).value(),
        backend().default_stream_executor(), backend().memory_allocator());
    if (!result.status().ok()) {
      TF_EXPECT_OK(result.status())
          << "HLO compilation failed: " << result.status();
      return "";
    }
    HloPrintOptions print_opts;
    print_opts.set_print_operand_shape(false);
    return (*result)->ToString(print_opts);
  }

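  // Runs `hlo_string` with TYPE substituted by each of f16/f32/f64 and
  // expects the optimized module to contain a fused conv-bias-activation
  // custom-call instead of a plain forward-conv custom-call.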
  void TestMatchWithAllTypes(absl::string_view hlo_string) {
    for (absl::string_view type : {"f16", "f32", "f64"}) {
      const std::string hlo_with_new_type =
          absl::StrReplaceAll(hlo_string, {{"TYPE", type}});
      std::string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type);
      EXPECT_THAT(optimized_hlo_string,
                  Not(HasSubstr(kCudnnConvForwardCallTarget)))
          << optimized_hlo_string;
      EXPECT_THAT(optimized_hlo_string,
                  HasSubstr(kCudnnConvBiasActivationForwardCallTarget));
      EXPECT_TRUE(RunAndCompare(hlo_with_new_type, ErrorSpec{0.01}))
          << optimized_hlo_string;
    }
  }

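  // Compiles `pre_hlo_string`, expects the clamp pattern to be absorbed into
  // a cudnn custom-call (so no `Convert` survives), and FileCheck-verifies
  // the optimized HLO against `post_hlo_string`.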
  void TestClamp(absl::string_view pre_hlo_string,
                 absl::string_view post_hlo_string) {
    std::string optimized_hlo_string = GetOptimizedHlo(pre_hlo_string);
    EXPECT_THAT(optimized_hlo_string, Not(HasSubstr("Convert")));
    EXPECT_THAT(optimized_hlo_string, HasSubstr("__cudnn$conv"));
    EXPECT_TRUE(RunAndCompare(pre_hlo_string, ErrorSpec{0.01}))
        << pre_hlo_string;

    StatusOr<bool> filecheck_result =
        RunFileCheck(optimized_hlo_string, post_hlo_string);
    ASSERT_TRUE(filecheck_result.ok()) << filecheck_result.status();
    EXPECT_TRUE(*filecheck_result);
  }

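  // Like TestMatchWithAllTypes, but expects the pattern *not* to be fused:
  // the optimized HLO should keep the plain forward-conv custom-call.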
  void TestNotMatchWithAllTypes(absl::string_view hlo_string) {
    for (absl::string_view type : {"f16", "f32", "f64"}) {
      const std::string hlo_with_new_type =
          absl::StrReplaceAll(hlo_string, {{"TYPE", type}});
      std::string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type);
      SCOPED_TRACE(optimized_hlo_string);
      EXPECT_THAT(optimized_hlo_string, HasSubstr(kCudnnConvForwardCallTarget));
      EXPECT_THAT(optimized_hlo_string,
                  Not(HasSubstr(kCudnnConvBiasActivationForwardCallTarget)));
    }
  }
};

TEST_F(CudnnFusedConvRewriterTest, TestConvOnly) {
  // max(0, conv(x, w));
  TestMatchWithAllTypes(R"(
    HloModule Test

    ENTRY Test {
      zero = TYPE[] constant(0)
      zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={}

      input = TYPE[1,17,9,9] parameter(0)
      filter = TYPE[3,3,17,32] parameter(1)

      conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
      ROOT relu = TYPE[1,32,9,9] maximum(zeros, conv)
    })");
}

TEST_F(CudnnFusedConvRewriterTest, TestBias) {
  // max(0, conv(x, w) + bias);
  TestMatchWithAllTypes(R"(
    HloModule Test

    ENTRY Test {
      zero = TYPE[] constant(0)
      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}

      input = TYPE[1,3,3,64] parameter(0)
      filter = TYPE[3,3,64,64] parameter(1)
      bias = TYPE[64] parameter(2)

      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
      broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
      add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
    })");
}

TEST_F(CudnnFusedConvRewriterTest, TestSideInputOnly) {
  // max(0, conv(x, w) + side_input);
  TestMatchWithAllTypes(R"(
    HloModule Test

    ENTRY Test {
      zero = TYPE[] constant(0)
      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}

      input = TYPE[1,3,3,64] parameter(0)
      filter = TYPE[3,3,64,64] parameter(1)
      side_input = TYPE[1,3,3,64] parameter(2)

      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
      add1 = TYPE[1,3,3,64] add(conv, side_input)
      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
    })");
}

TEST_F(CudnnFusedConvRewriterTest, TestBiasAndSideInput) {
  // max(0, conv(x, w) + side_input + bias);
  TestMatchWithAllTypes(R"(
    HloModule Test

    ENTRY Test {
      zero = TYPE[] constant(0)
      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}

      input = TYPE[1,3,3,64] parameter(0)
      filter = TYPE[3,3,64,64] parameter(1)
      side_input = TYPE[1,3,3,64] parameter(2)
      bias = TYPE[64] parameter(3)

      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
      broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
      add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
      add2 = TYPE[1,3,3,64] add(add1, side_input)
      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
    })");
}

TEST_F(CudnnFusedConvRewriterTest, TestScaledConv) {
  // max(0, 0.999994934 * conv(x, w));
  TestMatchWithAllTypes(R"(
    HloModule Test

    ENTRY Test {
      zero = TYPE[] constant(0)
      zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={}
      alpha_conv_scalar = TYPE[] constant(0.999994934)

      input = TYPE[1,17,9,9] parameter(0)
      filter = TYPE[3,3,17,32] parameter(1)

      conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
      alpha_conv = TYPE[1,32,9,9] broadcast(alpha_conv_scalar), dimensions={}
      scaled_conv = TYPE[1,32,9,9] multiply(conv, alpha_conv)
      ROOT relu = TYPE[1,32,9,9] maximum(zeros, scaled_conv)
    })");
}

TEST_F(CudnnFusedConvRewriterTest, TestNoCrashOnInf) {
  EXPECT_TRUE(RunAndCompare(R"(
    HloModule Test

    ENTRY Test {
      zero = f32[] constant(inf)
      zeros = f32[1,32,9,9] broadcast(zero), dimensions={}
      alpha_conv_scalar = f32[] constant(0.999994934)

      input = f32[1,17,9,9] parameter(0)
      filter = f32[3,3,17,32] parameter(1)

      conv = f32[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
      alpha_conv = f32[1,32,9,9] broadcast(alpha_conv_scalar), dimensions={}
      scaled_conv = f32[1,32,9,9] multiply(conv, alpha_conv)
      ROOT relu = f32[1,32,9,9] maximum(zeros, scaled_conv)
    })",
                            ErrorSpec{0.01}));
}

TEST_F(CudnnFusedConvRewriterTest, TestScaledConvAndSideInput) {
  // max(0, conv(x, w) + 0.899994934 * side_input);
  TestMatchWithAllTypes(R"(
    HloModule Test

    ENTRY Test {
      zero = TYPE[] constant(0)
      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
      alpha_side_input_scalar = TYPE[] constant(0.899994934)
      alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}

      input = TYPE[1,3,3,64] parameter(0)
      filter = TYPE[3,3,64,64] parameter(1)
      side_input = TYPE[1,3,3,64] parameter(2)

      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
      scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
      add1 = TYPE[1,3,3,64] add(conv, scaled_side_input)
      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
    })");
}

TEST_F(CudnnFusedConvRewriterTest, TestScaledConvAndScaledSideInput) {
  // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input);
  TestMatchWithAllTypes(R"(
    HloModule Test

    ENTRY Test {
      zero = TYPE[] constant(0)
      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
      alpha_conv_scalar = TYPE[] constant(0.999994934)
      alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
      alpha_side_input_scalar = TYPE[] constant(0.899994934)
      alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}

      input = TYPE[1,3,3,64] parameter(0)
      filter = TYPE[3,3,64,64] parameter(1)
      side_input = TYPE[1,3,3,64] parameter(2)

      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
      scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv)
      scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
      add1 = TYPE[1,3,3,64] add(scaled_conv, scaled_side_input)
      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
    })");
}

TEST_F(CudnnFusedConvRewriterTest, TestScaledConvAndScaledSideInputWithBias) {
  // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input + bias);
  TestMatchWithAllTypes(R"(
    HloModule Test

    ENTRY Test {
      zero = TYPE[] constant(0)
      zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
      alpha_conv_scalar = TYPE[] constant(0.999994934)
      alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
      alpha_side_input_scalar = TYPE[] constant(0.899994934)
      alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}

      input = TYPE[1,3,3,64] parameter(0)
      filter = TYPE[3,3,64,64] parameter(1)
      side_input = TYPE[1,3,3,64] parameter(2)
      bias = TYPE[64] parameter(3)

      conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
      scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv)
      scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
      broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
      add1 = TYPE[1,3,3,64] add(scaled_conv, broadcasted_bias)
      add2 = TYPE[1,3,3,64] add(add1, scaled_side_input)
      ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
    })");
}

TEST_F(CudnnFusedConvRewriterTest, TestMatchMaxZeroOnly) {
  // max(0.1, conv(x, w)) shouldn't match.
  TestNotMatchWithAllTypes(R"(
    HloModule Test

    ENTRY Test {
      point_one = TYPE[] constant(0.1)
      point_ones = TYPE[1,32,9,9] broadcast(point_one), dimensions={}

      input = TYPE[1,17,9,9] parameter(0)
      filter = TYPE[3,3,17,32] parameter(1)

      conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
      ROOT relu = TYPE[1,32,9,9] maximum(point_ones, conv)
    })");
}

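// Op metadata attached to the convolution should survive the rewrite into a
// custom-call.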
TEST_F(CudnnFusedConvRewriterTest, PreservesMetadata) {
  const char* kHloString = R"(
    HloModule Test

    ENTRY Test {
      zero = f32[] constant(0)
      zeros = f32[1,32,9,9] broadcast(zero), dimensions={}

      input = f32[1,17,9,9] parameter(0)
      filter = f32[3,3,17,32] parameter(1)

      conv = f32[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1, metadata={op_type="foo" op_name="bar"}
      ROOT relu = f32[1,32,9,9] maximum(zeros, conv)
    })";

  const std::string optimized_hlo_string =
      backend()
          .compiler()
          ->RunHloPasses(
              ParseAndReturnVerifiedModule(kHloString, GetModuleConfigForTest())
                  .value(),
              backend().default_stream_executor(), backend().memory_allocator())
          .value()
          ->ToString();
  EXPECT_THAT(optimized_hlo_string,
              ::testing::ContainsRegex(
                  R"(custom-call.*metadata=\{op_type="foo" op_name="bar"\})"));
}

TEST_F(CudnnFusedConvRewriterTest, TestPreservesFeatureGroupCount) {
  // The convolution below would crash if feature_group_count is not preserved.
  const char* kHloString = R"(
    HloModule jaxpr_computation__6.19

    primitive_computation__1.4 {
      parameter.5 = f32[] parameter(0)
      parameter.6 = f32[] parameter(1)
      ROOT add.7 = f32[] add(parameter.5, parameter.6)
    }

    ENTRY jaxpr_computation__7.8 {
      parameter.11 = f32[2,64,64,53]{3,2,1,0} parameter(1)
      parameter.10 = f32[3,3,1,53]{3,2,1,0} parameter(0)
      convolution.12 = f32[2,64,64,53]{3,2,1,0} convolution(parameter.11, parameter.10), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=53
      constant.13 = f32[] constant(0)
      broadcast.14 = f32[2,64,64,53]{3,2,1,0} broadcast(constant.13), dimensions={}
      maximum.15 = f32[2,64,64,53]{3,2,1,0} maximum(convolution.12, broadcast.14)
      ROOT reduce.17 = f32[] reduce(maximum.15, constant.13), dimensions={0,1,2,3}, to_apply=primitive_computation__1.4
    }
  )";
  EXPECT_TRUE(RunAndCompare(kHloString, ErrorSpec{0.01}));
}

TEST_F(CudnnFusedConvRewriterTest, TestConvInt8ToInt8) {
  // max(0, clamp(conv(x, w)))); for int8_t
  TestClamp(
      // pre_hlo
      R"(
    HloModule Test

    ENTRY Test {
      zero = s8[] constant(0)
      zeros = s8[1,32,9,9] broadcast(zero), dimensions={}

      input = s8[1,17,9,9] parameter(0)
      filter = s8[3,3,17,32] parameter(1)

      inputs32 = s32[1,17,9,9] convert(input)
      filters32 = s32[3,3,17,32] convert(filter)

      conv = s32[1,32,9,9] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1

      lower = s32[] constant(-128)
      lowers = s32[1,32,9,9] broadcast(lower), dimensions={}
      upper = s32[] constant(127)
      uppers = s32[1,32,9,9] broadcast(upper), dimensions={}

      clamp = s32[1,32,9,9] clamp(lowers, conv, uppers)

      ROOT convert = s8[1,32,9,9] convert(clamp)
    })",
      // post_hlo
      R"(
      ; CHECK-LABEL: ENTRY %Test (input: s8[1,17,9,9], filter: s8[3,3,17,32]) -> s8[1,32,9,9] {
      ; CHECK:  %cudnn-conv{{(\.[0-9])?}} = (s8[1,32,9,9]{1,3,2,0}, u8[{{[0-9]*}}]{0}) custom-call(%fusion{{(\.[0-9])?}}, %fusion{{(\.[0-9])?}}), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convForward", backend_config=
      )");
}

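// An s8->s32 conv whose result is converted to f32 should become a forward
// conv custom-call with f32 output.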
TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloat) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      input = s8[1,17,9,9] parameter(0)
      filter = s8[3,3,17,32] parameter(1)

      inputs32 = s32[1,17,9,9] convert(input)
      filters32 = s32[3,3,17,32] convert(filter)

      conv = s32[1,32,9,9] convolution(inputs32, filters32),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01

      ROOT convert = f32[1,32,9,9] convert(conv)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  EXPECT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::GetTupleElement(
                             m::CustomCall(kCudnnConvForwardCallTarget), 0)
                             .WithShape(F32, {1, 32, 9, 9})));
}

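// An s8 conv with f32 bias and s8 side input should fuse into a single
// conv-bias-activation custom-call that takes the raw s8 parameters.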
TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToInt8BiasSideInput) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      input = s32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
      filter = s32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
      side_input = f32[1,32,9,9] convert(s8[1,32,9,9] parameter(3))

      conv = s32[1,32,9,9] convolution(input, filter),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      conv_f32 = f32[1,32,9,9] convert(conv)
      ROOT root = s8[1,32,9,9] convert(clamp(f32[1,32,9,9] broadcast(f32[] constant(-128)),
                                             add(add(conv_f32, bias), side_input),
                                             f32[1,32,9,9] broadcast(f32[] constant(127))))
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  // Simplify new `convert`'s that may be added to the graph.
  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());

  SCOPED_TRACE(m->ToString());
  EXPECT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(m::GetTupleElement(
                     m::CustomCall(kCudnnConvBiasActivationForwardCallTarget,
                                   m::Parameter(0), m::Parameter(1),
                                   m::Parameter(2), m::Parameter(3)),
                     0)
                     .WithShape(S8, {1, 32, 9, 9})));
}

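// A relu applied after the s32->s8 clamp+convert should still fold into the
// fused conv as its activation mode.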
TEST_F(CudnnFusedConvRewriterHloTest, TestReluAfterConvert) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      input = s32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
      filter = s32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))

      conv = s32[1,32,9,9] convolution(input, filter),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      conv_s8 = s8[1,32,9,9] convert(clamp(s32[1,32,9,9] broadcast(s32[] constant(-128)),
                                           conv,
                                           s32[1,32,9,9] broadcast(s32[] constant(127))))
      zeros = s8[1,32,9,9] broadcast(s8[] constant(0)), dimensions={}
      ROOT root = s8[1,32,9,9] maximum(conv_s8, zeros)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  // Simplify new `convert`'s that may be added to the graph.
  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::GetTupleElement(
              m::CustomCall(
                  &conv, kCudnnConvBiasActivationForwardCallTarget,
                  m::Parameter(0),  //
                  m::Parameter(1),  //
                  m::Broadcast(
                      m::ConstantEffectiveScalar(0).WithElementType(F32))),
              0)
              .WithShape(S8, {1, 32, 9, 9})));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
}

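// As above, but with f32 bias and side input and the result left in f32
// rather than clamped back to s8.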
TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloatBiasSideInput) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      input = s8[1,17,9,9] parameter(0)
      filter = s8[3,3,17,32] parameter(1)
      bias = f32[32] parameter(2)
      bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
      side_input_f32 = f32[1,32,9,9] parameter(3)

      inputs32 = s32[1,17,9,9] convert(input)
      filters32 = s32[3,3,17,32] convert(filter)

      conv = s32[1,32,9,9] convolution(inputs32, filters32),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      conv_f32 = f32[1,32,9,9] convert(conv)
      sum1 = add(conv_f32, bias_broadcast)
      ROOT sum2 = add(sum1, side_input_f32)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  // Simplify new `convert`'s that may be added to the graph.
  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());

  SCOPED_TRACE(m->ToString());
  EXPECT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(m::GetTupleElement(
                     m::CustomCall(kCudnnConvBiasActivationForwardCallTarget,
                                   m::Parameter(0), m::Parameter(1),
                                   m::Parameter(2), m::Parameter(3)),
                     0)
                     .WithShape(F32, {1, 32, 9, 9})));
}

// The ReshapeMover pass changes
//   reshape(side_input) * alpha -->
//   reshape(side_input * alpha).
// Make sure we can pattern-match this.
TEST_F(CudnnFusedConvRewriterHloTest, Int8SideInputWithScaleAndReshape) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      input = s32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
      filter = s32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
      side_input_scale = f32[2592] broadcast(f32[] constant(0.25)), dimensions={}
      side_input = f32[1,32,9,9] reshape(multiply(f32[2592] convert(s8[2592] parameter(3)), side_input_scale))

      conv = s32[1,32,9,9] convolution(input, filter),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      ROOT root = s8[1,32,9,9] convert(clamp(f32[1,32,9,9] broadcast(f32[] constant(-128)),
                                             add(add(f32[1,32,9,9] convert(conv), bias), side_input),
                                             f32[1,32,9,9] broadcast(f32[] constant(127))))
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  // Simplify new `convert`'s that may be added to the graph.
  HloPassFix<HloPassPipeline> simplify("simplify");
  simplify.AddPass<AlgebraicSimplifier>(AlgebraicSimplifierOptions{});
  simplify.AddPass<ReshapeMover>();
  simplify.AddPass<ConvertMover>();
  TF_ASSERT_OK(RunHloPass(&simplify, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv = nullptr;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::GetTupleElement(
              m::CustomCall(
                  &conv, kCudnnConvBiasActivationForwardCallTarget,
                  m::Parameter(0),  //
                  m::Parameter(1),  //
                  m::Parameter(2),  //
                  m::Reshape(m::Parameter(3)).WithShape(S8, {1, 32, 9, 9})),
              0)
              .WithShape(S8, {1, 32, 9, 9})));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.conv_result_scale(), 1);
  EXPECT_EQ(config.side_input_scale(), 0.25);
}

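// multiply(convert(conv(...)), alpha) should fuse alpha into the
// custom-call's conv_result_scale.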
TEST_F(CudnnFusedConvRewriterHloTest, FuseAlpha) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      input = s8[1,17,9,9] parameter(0)
      filter = s8[3,3,17,32] parameter(1)
      inputs32 = s32[1,17,9,9] convert(input)
      filters32 = s32[3,3,17,32] convert(filter)
      alpha = f32[] constant(42)
      alpha_broadcast = f32[1,32,9,9] broadcast(alpha), dimensions={}

      conv = s32[1,32,9,9] convolution(inputs32, filters32),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      convert = f32[1,32,9,9] convert(conv)
      ROOT root = multiply(convert, alpha_broadcast)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv = nullptr;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::GetTupleElement(
              m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget),
              0)
              .WithShape(F32, {1, 32, 9, 9})));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.conv_result_scale(), 42);
}

TEST_F(CudnnFusedConvRewriterHloTest, FuseRelu) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] parameter(0)
      filters = f32[3,3,17,32] parameter(1)
      bias = f32[32] parameter(2)
      bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
      zero = f32[] constant(0)
      zeros = f32[1,32,9,9] broadcast(zero), dimensions={}
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      sum = add(conv, bias_broadcast)
      ROOT relu = maximum(sum, zeros)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::GetTupleElement(
              m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
                            m::Parameter(0), m::Parameter(1), m::Parameter(2)),
              0)
              .WithShape(F32, {1, 32, 9, 9})));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
}

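// If the relu's input has other users, fusing the relu would force the
// unfused value to be computed anyway, so the rewriter should leave the
// activation mode as kNone.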
TEST_F(CudnnFusedConvRewriterHloTest, DontFuseReluIfMultipleUses) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] parameter(0)
      filters = f32[3,3,17,32] parameter(1)
      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
      zeros = f32[1,32,9,9] broadcast(f32[] constant(0)), dimensions={}
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      sum = add(conv, bias)
      relu = maximum(sum, zeros)
      not_relu = minimum(sum, zeros)
      ROOT root = tuple(relu, not_relu)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(m::Tuple(
          m::MaximumAnyOrder(
              m::Broadcast(m::ConstantEffectiveScalar(0)),
              m::GetTupleElement(
                  m::CustomCall(
                      &conv, kCudnnConvBiasActivationForwardCallTarget,
                      m::Parameter(0), m::Parameter(1), m::Parameter(2)),
                  0)
                  .WithShape(F32, {1, 32, 9, 9})),
          m::Minimum())));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
}

TEST_F(CudnnFusedConvRewriterHloTest, DontFuseAlphaIfMultipleUsers) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] parameter(0)
      filters = f32[3,3,17,32] parameter(1)
      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
      alpha = f32[1,32,9,9] broadcast(f32[] parameter(3)), dimensions={}
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      sum = add(multiply(alpha, conv), bias)
      ROOT root = tuple(conv, sum)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv1;
  const HloInstruction* conv2;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(m::Tuple(
          m::GetTupleElement(m::CustomCall(&conv1), 0),
          m::AddAnyOrder(m::Broadcast(m::Parameter(2)),
                         m::MultiplyAnyOrder(
                             m::Broadcast(m::Parameter(3)),
                             m::GetTupleElement(m::CustomCall(&conv2), 0))))));
  EXPECT_EQ(conv1, conv2);
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv1->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.conv_result_scale(), 1);
  EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
}

TEST_F(CudnnFusedConvRewriterHloTest, DontFuseBiasIfMultipleUsers) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] parameter(0)
      filters = f32[3,3,17,32] parameter(1)
      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      ROOT root = tuple(conv, add(conv, bias))
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv1;
  const HloInstruction* conv2;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(m::Tuple(
          m::GetTupleElement(m::CustomCall(&conv1), 0),
          m::AddAnyOrder(m::Broadcast(m::Parameter(2)),
                         m::GetTupleElement(m::CustomCall(&conv2), 0)))));
  EXPECT_EQ(conv1, conv2);
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv1->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.conv_result_scale(), 1);
  EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
}

TEST_F(CudnnFusedConvRewriterHloTest, DontFuseSideInputThroughRelu) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] parameter(0)
      filters = f32[3,3,17,32] parameter(1)
      side_input = f32[1,32,9,9] parameter(2)
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      relu = maximum(conv, f32[1,32,9,9] broadcast(f32[] constant(0)))
      ROOT root = add(relu, side_input)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(m::AddAnyOrder(
          m::Parameter(2),
          m::GetTupleElement(
              m::CustomCall(&conv, m::Parameter(0), m::Parameter(1),
                            m::Broadcast(m::ConstantEffectiveScalar(0))),
              0))));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.conv_result_scale(), 1);
  EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
}

TEST_F(CudnnFusedConvRewriterHloTest, DontFuseBiasThroughRelu) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] parameter(0)
      filters = f32[3,3,17,32] parameter(1)
      bias = f32[1,32,9,9] broadcast(f32[32] parameter(2)), dimensions={1}
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      relu = maximum(conv, f32[1,32,9,9] broadcast(f32[] constant(0)))
      ROOT root = add(relu, bias)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
  ASSERT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::AddAnyOrder(
                  m::Broadcast(m::Parameter(2)),
                  m::GetTupleElement(m::CustomCall(
                      &conv, m::Parameter(0), m::Parameter(1),
                      m::Broadcast(m::ConstantEffectiveScalar(0)))))));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.conv_result_scale(), 1);
  EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
}

TEST_F(CudnnFusedConvRewriterHloTest, DontFuseSideInputIfMultipleUsers) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] parameter(0)
      filters = f32[3,3,17,32] parameter(1)
      side_input = f32[1,32,9,9] parameter(2)
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      ROOT root = tuple(conv, add(conv, side_input))
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv1;
  const HloInstruction* conv2;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(m::Tuple(
          m::GetTupleElement(m::CustomCall(&conv1), 0),
          m::AddAnyOrder(m::Parameter(2),
                         m::GetTupleElement(m::CustomCall(&conv2), 0)))));
  EXPECT_EQ(conv1, conv2);
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv1->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.conv_result_scale(), 1);
  EXPECT_EQ(config.activation_mode(), se::dnn::kNone);
}

TEST_F(CudnnFusedConvRewriterHloTest, DontFuseConvertToF16IfMultipleUsers) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] convert(f16[1,17,9,9] parameter(0))
      filters = f32[3,3,17,32] convert(f16[3,3,17,32] parameter(1))
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      ROOT root = tuple(conv, f16[1,32,9,9] convert(conv))
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv1;
  const HloInstruction* conv2;
  ASSERT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Tuple(
                  m::GetTupleElement(m::CustomCall(&conv1), 0),
                  m::Convert(m::GetTupleElement(m::CustomCall(&conv2), 0)))));
  EXPECT_EQ(conv1, conv2);
}

TEST_F(CudnnFusedConvRewriterHloTest, DontFuseToS8IfMultipleUsers) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] convert(s8[1,17,9,9] parameter(0))
      filters = f32[3,3,17,32] convert(s8[3,3,17,32] parameter(1))
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      conv_s8 = s8[1,32,9,9] convert(clamp(
                  f32[1,32,9,9] broadcast(f32[] constant(-128)),
                  conv,
                  f32[1,32,9,9] broadcast(f32[] constant(127))))
      ROOT root = tuple(conv, conv_s8)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv1;
  const HloInstruction* conv2;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(m::Tuple(
          m::GetTupleElement(m::CustomCall(&conv1), 0),
          m::Convert(m::Clamp(m::Op(),  //
                              m::GetTupleElement(m::CustomCall(&conv2), 0),
                              m::Op())))));
  EXPECT_EQ(conv1, conv2);
}

TEST_F(CudnnFusedConvRewriterHloTest, FuseBias) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] parameter(0)
      filters = f32[3,3,17,32] parameter(1)
      bias = f32[32] parameter(2)
      bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      ROOT root = add(conv, bias_broadcast)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::GetTupleElement(
              m::CustomCall(kCudnnConvBiasActivationForwardCallTarget,
                            m::Parameter(0), m::Parameter(1), m::Parameter(2)),
              0)
              .WithShape(F32, {1, 32, 9, 9})));
}

TEST_F(CudnnFusedConvRewriterHloTest, FuseSideInput) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] parameter(0)
      filters = f32[3,3,17,32] parameter(1)
      side_input = f32[1,32,9,9] parameter(2)
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      ROOT root = add(conv, side_input)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::GetTupleElement(
              m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
                            m::Parameter(0), m::Parameter(1),
                            m::Broadcast(m::ConstantEffectiveScalar(0))
                                .WithShape(F32, {32}),
                            m::Parameter(2)),
              0)
              .WithShape(F32, {1, 32, 9, 9})));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.side_input_scale(), 1);
}

TEST_F(CudnnFusedConvRewriterHloTest, FuseScaledSideInput) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] parameter(0)
      filters = f32[3,3,17,32] parameter(1)
      side_input = f32[1,32,9,9] parameter(2)
      side_input_scale = f32[] constant(42)
      side_input_scale_broadcast = f32[1,32,9,9] broadcast(side_input_scale), dimensions={}
      side_input_product = multiply(side_input, side_input_scale_broadcast)
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      ROOT root = add(conv, side_input_product)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::GetTupleElement(
              m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
                            m::Parameter(0), m::Parameter(1),
                            m::Broadcast(m::ConstantEffectiveScalar(0))
                                .WithShape(F32, {32}),
                            m::Parameter(2)),
              0)
              .WithShape(F32, {1, 32, 9, 9})));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.side_input_scale(), 42);
}

TEST_F(CudnnFusedConvRewriterHloTest, FuseBiasAndSideInput) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] parameter(0)
      filters = f32[3,3,17,32] parameter(1)
      bias = f32[32] parameter(2)
      side_input = f32[1,32,9,9] parameter(3)
      bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      sum = add(conv, side_input)
      ROOT sum2 = add(sum, bias_broadcast)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::GetTupleElement(
              m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
                            m::Parameter(0), m::Parameter(1), m::Parameter(2),
                            m::Parameter(3)),
              0)
              .WithShape(F32, {1, 32, 9, 9})));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.side_input_scale(), 1);
}

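// A scalar (rank-0) bias must be rewritten into a per-channel f32[32] vector
// so it can be fused as the custom-call's bias operand.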
TEST_F(CudnnFusedConvRewriterHloTest, EffectiveScalarBias) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] parameter(0)
      filters = f32[3,3,17,32] parameter(1)
      bias = f32[1,32,9,9] broadcast(f32[] parameter(2)), dimensions={}
      conv = f32[1,32,9,9] convolution(inputs, filters),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      ROOT root = add(conv, bias)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::GetTupleElement(
              m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
                            m::Parameter(0), m::Parameter(1),
                            m::Broadcast(m::Parameter(2)).WithShape(F32, {32})),
              0)
              .WithShape(F32, {1, 32, 9, 9})));
}

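// f16 operands convolved in f32 and converted back to f16 should be
// strength-reduced to a single f16 fused convolution.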
TEST_F(CudnnFusedConvRewriterHloTest,StrengthReduceF32ToF16)1199 TEST_F(CudnnFusedConvRewriterHloTest, StrengthReduceF32ToF16) {
1200   const std::string module_str = R"(
1201     HloModule Test
1202 
1203     ENTRY Test {
1204       inputs = f16[1,17,9,9] parameter(0)
1205       filters = f16[3,3,17,32] parameter(1)
1206       bias = f16[32] parameter(2)
1207       side_input = f16[1,32,9,9] parameter(3)
1208 
1209       inputs_f32 = f32[1,17,9,9] convert(inputs)
1210       filters_f32 = f32[3,3,17,32] convert(filters)
1211       bias_f32 = f32[32] convert(bias)
1212       bias_broadcast = f32[1,32,9,9] broadcast(bias_f32), dimensions={1}
1213       side_input_f32 = f32[1,32,9,9] convert(side_input)
1214       conv = f32[1,32,9,9] convolution(inputs_f32, filters_f32),
1215                window={size=3x3 pad=1_1x1_1},
1216                dim_labels=bf01_01io->bf01
1217       sum = add(conv, side_input_f32)
1218       sum2 = add(sum, bias_broadcast)
1219       ROOT conv_f16 = f16[1,32,9,9] convert(sum2)
1220     })";
1221   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));
1222 
1223   GpuConvRewriter rewriter;
1224   TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
1225   CudnnFusedConvRewriter fuser;
1226   TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());
1227 
1228   // Simplify new `convert`'s that may be added to the graph.
1229   AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
1230   TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
1231 
1232   SCOPED_TRACE(m->ToString());
1233   const HloInstruction* conv;
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::GetTupleElement(
              m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
                            m::Parameter(0), m::Parameter(1), m::Parameter(2),
                            m::Parameter(3)),
              0)
              .WithShape(F16, {1, 32, 9, 9})));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.side_input_scale(), 1);
}

// We should be able to lower this to an f16 convolution even though the
// f16-ness of the inputs is hidden behind broadcast/transpose/reshape.
TEST_F(CudnnFusedConvRewriterHloTest, BroadcastReshapeTransposeAfterConvert) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f32[1,17,9,9] reshape(f32[1377] convert(f16[1377] parameter(0)))
      filters = f32[3,3,17,32] transpose(f32[17,32,3,3] convert(f16[17,32,3,3] parameter(1))), dimensions={2,3,0,1}
      bias = f16[1,32,9,9] broadcast(f16[32] parameter(2)), dimensions={1}
      side_input = f16[1,32,9,9] reshape(f16[2592] parameter(3))

      conv_f32 = f32[1,32,9,9] convolution(inputs, filters),
                 window={size=3x3 pad=1_1x1_1},
                 dim_labels=bf01_01io->bf01
      conv_f16 = f16[1,32,9,9] convert(conv_f32)
      ROOT root = f16[1,32,9,9] add(add(conv_f16, side_input), bias)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  // Simplify new `convert`s that may be added to the graph.
  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
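  // The custom call should take f16 operands: each convert-to-f32 is
  // cancelled by a convert back to f16, leaving only the reshape/transpose
  // shape ops between the parameters and the conv.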
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(m::GetTupleElement(
                     m::CustomCall(
                         &conv, kCudnnConvBiasActivationForwardCallTarget,
                         m::Convert(m::Reshape(m::Convert(m::Parameter(0))))
                             .WithElementType(F16),
                         m::Convert(m::Transpose(m::Convert(m::Parameter(1))))
                             .WithElementType(F16),
                         m::Parameter(2), m::Reshape(m::Parameter(3))),
                     0)
                     .WithShape(F16, {1, 32, 9, 9})));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.side_input_scale(), 1);
}

TEST_F(CudnnFusedConvRewriterHloTest, NoStrengthReduceF32ToF16IfBiasIsF32) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f16[1,17,9,9] parameter(0)
      filters = f16[3,3,17,32] parameter(1)
      bias = f32[32] parameter(2)
      side_input = f16[1,32,9,9] parameter(3)

      inputs_f32 = f32[1,17,9,9] convert(inputs)
      filters_f32 = f32[3,3,17,32] convert(filters)
      bias_broadcast = f32[1,32,9,9] broadcast(bias), dimensions={1}
      side_input_f32 = f32[1,32,9,9] convert(side_input)
      conv = f32[1,32,9,9] convolution(inputs_f32, filters_f32),
               window={size=3x3 pad=1_1x1_1},
               dim_labels=bf01_01io->bf01
      sum = add(conv, side_input_f32)
      sum2 = add(sum, bias_broadcast)
      ROOT conv_f16 = f16[1,32,9,9] convert(sum2)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  // Simplify new `convert`s that may be added to the graph.
  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
  // fp16 convs only support fp16 biases.  Because the bias here is fp32, it
  // doesn't get fused in, and we get an fp32 conv.
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::Convert(m::GetTupleElement(
                         m::CustomCall(
                             &conv, kCudnnConvBiasActivationForwardCallTarget,
                             m::Convert(m::Parameter(0)).WithElementType(F32),
                             m::Convert(m::Parameter(1)).WithElementType(F32),
                             m::Parameter(2),
                             m::Convert(m::Parameter(3)).WithElementType(F32)),
                         0))
              .WithShape(F16, {1, 32, 9, 9})));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.side_input_scale(), 1);
}

TEST_F(CudnnFusedConvRewriterHloTest, F32Constants) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f16[1,2,2,2] parameter(0)
      filters_f32 = f32[1,1,2,2] constant({{{{1, 2},{3, 4}}}})
      bias = f16[2] parameter(1)
      bias_f32 = f32[2] convert(bias)
      side_input_f32 = f32[1,2,2,2] constant({{
        {{0.5, 0.25}, {0.125, 0.0625}},
        {{0.5, 0.25}, {0.125, 0.0625}}
      }})

      inputs_f32 = f32[1,2,2,2] convert(inputs)
      bias_broadcast = f32[1,2,2,2] broadcast(bias_f32), dimensions={1}
      conv = f32[1,2,2,2] convolution(inputs_f32, filters_f32),
               window={size=1x1}, dim_labels=bf01_01io->bf01
      sum = add(conv, side_input_f32)
      sum2 = add(sum, bias_broadcast)
      ROOT conv_f16 = f16[1,2,2,2] convert(sum2)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  // Simplify new `convert`s that may be added to the graph, and fold the
  // converts back into constants.
  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
  HloConstantFolding constant_folding;
  TF_ASSERT_OK(RunHloPass(&constant_folding, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
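  // All the f32 constants above (1..4 and the powers of two) are exactly
  // representable in f16, so after constant folding the filter and side
  // input should be fused into the call as f16 constants.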
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(m::GetTupleElement(
                     m::CustomCall(
                         &conv, kCudnnConvBiasActivationForwardCallTarget,
                         m::Parameter(0), m::Constant().WithElementType(F16),
                         m::Parameter(1), m::Constant().WithElementType(F16)),
                     0)
                     .WithShape(F16, {1, 2, 2, 2})));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.side_input_scale(), 1);
}

TEST_F(CudnnFusedConvRewriterHloTest, F32ConstantsNotLosslesslyConvertible) {
  const std::string module_str = R"(
    HloModule Test

    ENTRY Test {
      inputs = f16[1,2,2,2] parameter(0)
      filters_f32 = f32[1,1,2,2] constant({{{{1, 2.123456789},{3, 4}}}})
      bias = f16[2] parameter(1)
      bias_f32 = f32[2] convert(bias)
      side_input_f32 = f32[1,2,2,2] constant({{
        {{0.1, 0.2}, {0.3, 0.4}},
        {{0.5, 0.6}, {0.7, 0.8}}
      }})

      inputs_f32 = f32[1,2,2,2] convert(inputs)
      bias_broadcast = f32[1,2,2,2] broadcast(bias_f32), dimensions={1}
      conv = f32[1,2,2,2] convolution(inputs_f32, filters_f32),
               window={size=1x1}, dim_labels=bf01_01io->bf01
      sum = add(conv, side_input_f32)
      sum2 = add(sum, bias_broadcast)
      ROOT conv_f16 = f16[1,2,2,2] convert(sum2)
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  // Simplify new `convert`s that may be added to the graph, and fold the
  // converts back into constants.
  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());
  HloConstantFolding constant_folding;
  TF_ASSERT_OK(RunHloPass(&constant_folding, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
  // This doesn't get transformed into an f16 conv because the f32 constants
  // (e.g. 2.123456789 in the filter) are not losslessly expressible as f16.
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::Convert(m::GetTupleElement(
                         m::CustomCall(
                             &conv, kCudnnConvBiasActivationForwardCallTarget,
                             m::Convert(m::Parameter(0)).WithElementType(F32),
                             m::Constant().WithElementType(F32),
                             m::Convert(m::Parameter(1)).WithElementType(F32),
                             m::Constant().WithElementType(F32)),
                         0)
                         .WithShape(F32, {1, 2, 2, 2}))
              .WithElementType(F16)));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.side_input_scale(), 1);
}

TEST_F(CudnnFusedConvRewriterHloTest, FuseReluBeforeConvert) {
  const std::string module_str = R"(
  HloModule Test

  ENTRY Test {
    input = s8[1,17,9,9] parameter(0)
    filter = s8[3,3,17,32] parameter(1)
    inputs32 = s32[1,17,9,9] convert(input)
    filters32 = s32[3,3,17,32] convert(filter)

    conv = s32[1,32,9,9] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1

    zero = s32[] constant(0)
    zeros = s32[1,32,9,9] broadcast(zero), dimensions={}
    relu = maximum(conv, zeros)

    lower = s32[] constant(-128)
    lowers = s32[1,32,9,9] broadcast(lower), dimensions={}
    upper = s32[] constant(127)
    uppers = s32[1,32,9,9] broadcast(upper), dimensions={}

    clamp = s32[1,32,9,9] clamp(lowers, relu, uppers)

    ROOT convert = s8[1,32,9,9] convert(clamp)
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  // Simplify new `convert`s that may be added to the graph.
  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
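  // The relu should be fused as the custom call's activation mode, and the
  // rewriter is expected to synthesize a broadcasted zero bias because the
  // conv-bias-activation call always takes a bias operand.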
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::GetTupleElement(
              m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
                            m::Parameter(0),  //
                            m::Parameter(1),  //
                            m::Broadcast(m::ConstantEffectiveScalar(0))
                                .WithShape(F32, {32})),
              0)
              .WithShape(S8, {1, 32, 9, 9})));
  TF_ASSERT_OK_AND_ASSIGN(auto config,
                          conv->backend_config<CudnnConvBackendConfig>());
  EXPECT_EQ(config.activation_mode(), se::dnn::kRelu);
}

TEST_F(CudnnFusedConvRewriterHloTest, BiasTypeMatchesConvTypeIfFp) {
  const std::string module_str = R"(
  HloModule Test

  ENTRY Test {
    input = f64[1,17,9,9] parameter(0)
    filter = f64[3,3,17,32] parameter(1)
    bias = f64[1,32,9,9] broadcast(f64[32] convert(f32[32] parameter(2))), dimensions={1}
    conv = f64[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
    ROOT root = f64[1,32,9,9] add(conv, bias)
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  GpuConvRewriter rewriter;
  TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status());
  CudnnFusedConvRewriter fuser;
  TF_ASSERT_OK(RunHloPass(&fuser, m.get()).status());

  // Simplify new `convert`s that may be added to the graph.
  AlgebraicSimplifier algsimp(AlgebraicSimplifierOptions{});
  TF_ASSERT_OK(RunHloPass(&algsimp, m.get()).status());

  SCOPED_TRACE(m->ToString());
  const HloInstruction* conv;
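  // For floating-point convs the fused bias must match the conv's element
  // type, so the f32 bias parameter should be converted to f64 rather than
  // blocking the fusion.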
  ASSERT_THAT(
      m->entry_computation()->root_instruction(),
      GmockMatch(
          m::GetTupleElement(
              m::CustomCall(&conv, kCudnnConvBiasActivationForwardCallTarget,
                            m::Parameter(0),  //
                            m::Parameter(1),  //
                            m::Convert(m::Parameter(2)).WithShape(F64, {32})),
              0)
              .WithShape(F64, {1, 32, 9, 9})));
}

TEST_F(CudnnFusedConvRewriterTest, TestFusedConvInt8ToInt8) {
  // clamp(max(0, conv(x, w)+bias)); for int8_t
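  // The whole convert/clamp/relu/bias chain should collapse into a single s8
  // conv-bias-activation custom call, as the post_hlo CHECK lines verify.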
  TestClamp(
      // pre_hlo
      R"(
    HloModule Test

    ENTRY Test {
      zero = f32[] constant(0)
      zeros = f32[1,3,3,64] broadcast(zero), dimensions={}

      input = s8[1,3,3,64] parameter(0)
      filter = s8[3,3,64,64] parameter(1)
      bias = f32[64] parameter(2)

      inputs32 = s32[1,3,3,64] convert(input)
      filters32 = s32[3,3,64,64] convert(filter)

      conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1

      convfloat = f32[1,3,3,64] convert(conv)
      broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
      add1 = f32[1,3,3,64] add(convfloat, broadcasted_bias)
      relu = f32[1,3,3,64] maximum(zeros, add1)

      lower = f32[] constant(-128)
      lowers = f32[1,3,3,64] broadcast(lower), dimensions={}
      upper = f32[] constant(127)
      uppers = f32[1,3,3,64] broadcast(upper), dimensions={}

      clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)

      ROOT convert = s8[1,3,3,64] convert(clamp)
    })",
      // post_hlo
      R"(
      ; CHECK-LABEL: ENTRY %Test (input: s8[1,3,3,64], filter: s8[3,3,64,64], bias: f32[64]) -> s8[1,3,3,64]
      ; CHECK:  %cudnn-conv-bias-activation{{(\.[0-9])?}} = (s8[1,3,3,64]{3,2,1,0}, u8[{{[0-9]+}}]{0}) custom-call(%input, %copy{{(\.[0-9])?}}, %bias), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", backend_config=
      ; CHECK-NEXT:  ROOT %get-tuple-element{{(\.[0-9])?}} = s8[1,3,3,64]{3,2,1,0} get-tuple-element(%cudnn-conv-bias-activation{{(\.[0-9])?}}), index=0
      )");
}

// Disabled per b/190854862 or nvbugs/3326122.
TEST_F(CudnnFusedConvRewriterTest, DISABLED_TestFusedConvInt8ToFloat) {
  // max(0, convert<float>(conv<int32_t>(int8_x, int8_w)) + float_bias);
  // int8_t to float via bias.
  TestClamp(
      // pre_hlo
      R"(
    HloModule Test

    ENTRY Test {
      zero = f32[] constant(0)
      zeros = f32[1,3,3,64] broadcast(zero), dimensions={}

      input = s8[1,3,3,64] parameter(0)
      filter = s8[3,3,64,64] parameter(1)
      bias = f32[64] parameter(2)

      inputs32 = s32[1,3,3,64] convert(input)
      filters32 = s32[3,3,64,64] convert(filter)

      conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1

      convfloat = f32[1,3,3,64] convert(conv)
      broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
      add1 = f32[1,3,3,64] add(convfloat, broadcasted_bias)
      ROOT relu = f32[1,3,3,64] maximum(zeros, add1)
    })",
      // post_hlo
      R"(
      ; CHECK-LABEL: ENTRY %Test (input: s8[1,3,3,64], filter: s8[3,3,64,64], bias: f32[64]) -> f32[1,3,3,64] {
      ; CHECK:  %custom-call{{(\.[0-9])?}} = (f32[1,3,3,64]{3,2,1,0}, u8[{{[0-9]*}}]{0}) custom-call(%input, %copy{{(\.[0-9])?}}, %bias), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", backend_config=
      ; CHECK-NEXT:  ROOT %get-tuple-element{{(\.[0-9])?}} = f32[1,3,3,64]{3,2,1,0} get-tuple-element(%custom-call{{(\.[0-9])?}}), index=0
      )");
}

TEST_F(CudnnFusedConvRewriterTest,
       TestFusedConvWithScaledInt8SideInputBiasInt8ToInt8) {
  // clamp(max(0, alpha_conv * conv(x, w) + alpha_side *
  // convert<float>(int8_side_input) + bias)); for int8_t
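  // Both scaling factors should be absorbed into the custom call's backend
  // config (as the conv-result and side-input scales) rather than surviving
  // as explicit multiplies.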
  TestClamp(
      // pre_hlo
      R"(
    HloModule Test

    ENTRY Test {
      zero = f32[] constant(0)
      zeros = f32[1,3,3,64] broadcast(zero), dimensions={}
      alpha_conv_scalar = f32[] constant(0.999994934)
      alpha_conv = f32[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
      alpha_side_input_scalar = f32[] constant(0.899994934)
      alpha_side_input = f32[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}

      input = s8[1,3,3,64] parameter(0)
      filter = s8[3,3,64,64] parameter(1)
      side_input = s8[1,3,3,64] parameter(2)
      bias = f32[64] parameter(3)

      inputs32 = s32[1,3,3,64] convert(input)
      filters32 = s32[3,3,64,64] convert(filter)
      side_input_f32 = f32[1,3,3,64] convert(side_input)

      conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1

      convfloat = f32[1,3,3,64] convert(conv)
      scaled_conv = f32[1,3,3,64] multiply(convfloat, alpha_conv)
      scaled_side_input = f32[1,3,3,64] multiply(side_input_f32, alpha_side_input)
      broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
      add1 = f32[1,3,3,64] add(scaled_conv, broadcasted_bias)
      add2 = f32[1,3,3,64] add(add1, scaled_side_input)
      relu = f32[1,3,3,64] maximum(zeros, add2)

      lower = f32[] constant(-128)
      lowers = f32[1,3,3,64] broadcast(lower), dimensions={}
      upper = f32[] constant(127)
      uppers = f32[1,3,3,64] broadcast(upper), dimensions={}

      clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)

      ROOT convert = s8[1,3,3,64] convert(clamp)
    })",
      // post_hlo
      R"(
      ; CHECK-LABEL: ENTRY %Test (input: s8[1,3,3,64], filter: s8[3,3,64,64], side_input: s8[1,3,3,64], bias: f32[64]) -> s8[1,3,3,64] {
      ; CHECK:  %cudnn-conv-bias-activation{{(\.[0-9]+)?}} =
      ; CHECK-SAME: (s8[1,3,3,64]{3,2,1,0}, u8[{{[0-9]+}}]{0})
      ; CHECK-SAME: custom-call(%input, %copy{{(\.[0-9]+)?}}, %bias, %side_input),
      ; CHECK-SAME: window={size=3x3 pad=1_1x1_1},
      ; CHECK-SAME: dim_labels=b01f_01io->b01f,
      ; CHECK-SAME: custom_call_target="__cudnn$convBiasActivationForward",
      ; CHECK-NEXT: ROOT %get-tuple-element{{(\.[0-9]+)?}} = s8[1,3,3,64]{3,2,1,0} get-tuple-element(%cudnn-conv-bias-activation{{(\.[0-9]+)?}}), index=0
      )");
}

TEST_F(CudnnFusedConvRewriterTest,
       TestFusedConvWithScaledFloatSideInputBiasInt8ToInt8) {
  // From:
  //   convert<int8_t>(clamp(max(0, alpha_conv * conv(x, w) +
  //                              alpha_side * float_side_input + bias)));
  // To:
  //   convert<int8_t>(clamp(conv(int8_x, int8_w, float_alpha_side,
  //                              float_side_input, float_bias)));
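  // The custom call still produces f32 here; the trailing clamp and convert
  // to s8 are expected to remain behind as a separate loop fusion (see the
  // ROOT %fusion CHECK line below).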
  TestClamp(
      // pre_hlo
      R"(
    HloModule Test

    ENTRY Test {
      zero = f32[] constant(0)
      zeros = f32[1,3,3,64] broadcast(zero), dimensions={}
      alpha_conv_scalar = f32[] constant(0.999994934)
      alpha_conv = f32[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
      alpha_side_input_scalar = f32[] constant(0.899994934)
      alpha_side_input = f32[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}

      input = s8[1,3,3,64] parameter(0)
      filter = s8[3,3,64,64] parameter(1)
      side_input = f32[1,3,3,64] parameter(2)
      bias = f32[64] parameter(3)

      inputs32 = s32[1,3,3,64] convert(input)
      filters32 = s32[3,3,64,64] convert(filter)

      conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1

      convfloat = f32[1,3,3,64] convert(conv)
      scaled_conv = f32[1,3,3,64] multiply(convfloat, alpha_conv)
      scaled_side_input = f32[1,3,3,64] multiply(side_input, alpha_side_input)
      broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
      add1 = f32[1,3,3,64] add(scaled_conv, broadcasted_bias)
      add2 = f32[1,3,3,64] add(add1, scaled_side_input)
      relu = f32[1,3,3,64] maximum(zeros, add2)

      lower = f32[] constant(-128)
      lowers = f32[1,3,3,64] broadcast(lower), dimensions={}
      upper = f32[] constant(127)
      uppers = f32[1,3,3,64] broadcast(upper), dimensions={}

      clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)

      ROOT convert = s8[1,3,3,64] convert(clamp)
    })",
      // post_hlo
      R"(
      ; CHECK-LABEL: ENTRY %Test (input: s8[1,3,3,64], filter: s8[3,3,64,64], side_input: f32[1,3,3,64], bias: f32[64]) -> s8[1,3,3,64] {
      ; CHECK:  %cudnn-conv-bias-activation{{(\.[0-9])?}} = (f32[1,3,3,64]{3,2,1,0}, u8[{{[0-9]+}}]{0}) custom-call(%input, %copy{{(\.[0-9])?}}, %bias, %side_input), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", backend_config=
      ; CHECK:  ROOT %fusion = s8[1,3,3,64]{3,2,1,0} fusion(%get-tuple-element{{(\.[0-9])?}}), kind=kLoop, calls=%fused_computation
      )");
}

TEST_F(CudnnFusedConvRewriterTest,
       TestFusedConvWithScaledInt8SideInputBiasInt8ToFloat) {
  // From:
  //   clamp(max(0, alpha_conv * conv(x, w) +
  //             alpha_side * convert<float>(int8_side_input) + bias));
  // To:
  //   clamp(conv(int8_x, int8_w, float_alpha_side,
  //              convert<float>(int8_side_input), float_bias));
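  // The s8 side input can't feed the f32-output custom call directly, so its
  // convert<float> is expected to remain in the graph, as the CHECK lines
  // verify.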
  TestClamp(
      // pre_hlo
      R"(
    HloModule Test

    ENTRY Test {
      zero = f32[] constant(0)
      zeros = f32[1,3,3,64] broadcast(zero), dimensions={}
      alpha_conv_scalar = f32[] constant(0.999994934)
      alpha_conv = f32[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
      alpha_side_input_scalar = f32[] constant(0.899994934)
      alpha_side_input = f32[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}

      input = s8[1,3,3,64] parameter(0)
      filter = s8[3,3,64,64] parameter(1)
      side_input = s8[1,3,3,64] parameter(2)
      bias = f32[64] parameter(3)

      inputs32 = s32[1,3,3,64] convert(input)
      filters32 = s32[3,3,64,64] convert(filter)
      side_input_f32 = f32[1,3,3,64] convert(side_input)

      conv = s32[1,3,3,64] convolution(inputs32, filters32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1

      convfloat = f32[1,3,3,64] convert(conv)
      scaled_conv = f32[1,3,3,64] multiply(convfloat, alpha_conv)
      scaled_side_input = f32[1,3,3,64] multiply(side_input_f32, alpha_side_input)
      broadcasted_bias = f32[1,3,3,64] broadcast(bias), dimensions={3}
      add1 = f32[1,3,3,64] add(scaled_conv, broadcasted_bias)
      add2 = f32[1,3,3,64] add(add1, scaled_side_input)
      relu = f32[1,3,3,64] maximum(zeros, add2)

      lower = f32[] constant(-128)
      lowers = f32[1,3,3,64] broadcast(lower), dimensions={}
      upper = f32[] constant(127)
      uppers = f32[1,3,3,64] broadcast(upper), dimensions={}

      ROOT clamp = f32[1,3,3,64] clamp(lowers, relu, uppers)
    })",
      // post_hlo
      R"(
      ; CHECK-LABEL: ENTRY %Test (input: s8[1,3,3,64], filter: s8[3,3,64,64], side_input: s8[1,3,3,64], bias: f32[64]) -> f32[1,3,3,64] {
      ; CHECK:  %side_input_f32 = f32[1,3,3,64]{3,2,1,0} convert(%side_input)
      ; CHECK:  %cudnn-conv-bias-activation{{(\.[0-9])?}} = (f32[1,3,3,64]{3,2,1,0}, u8[{{[0-9]*}}]{0}) custom-call(%input, %copy{{(\.[0-9])?}}, %bias, %side_input_f32), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convBiasActivationForward", backend_config=
      ; CHECK:  ROOT %fusion = f32[1,3,3,64]{3,2,1,0} fusion(%get-tuple-element{{(\.[0-9])?}}), kind=kLoop, calls=%fused_computation
      )");
}

TEST_F(CudnnFusedConvRewriterTest, TestConvInt8ToInt8NoClamp) {
  // Check that integer convolution without clamp to int8_t is not allowed.
  // convert<int8_t>(custom_call<int32_t>(int32_x, int32_w,
  // cudnnConvolutionForward))
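  // The rewriter should return a non-OK status rather than emit an s8 result
  // that was never clamped into the s8 range.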
  const std::string module_str = absl::StrFormat(R"(
    HloModule Test

    ENTRY Test (input: s8[1,17,9,9], filter: s8[3,3,17,32]) -> s8[1,32,9,9] {
      zero = s8[] constant(0)
      zeros = s8[1,32,9,9]{3,2,1,0} broadcast(s8[] zero), dimensions={}
      input = s8[1,17,9,9]{3,2,1,0} parameter(0)
      filter = s8[3,3,17,32]{3,2,1,0} parameter(1)
      custom-call = (s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call(s8[1,17,9,9]{3,2,1,0} input, s8[3,3,17,32]{3,2,1,0} filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convForward", backend_config="{\"convResultScale\":1}"
      get-tuple-element = s32[1,32,9,9]{3,2,1,0} get-tuple-element((s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call), index=0
      convert = s8[1,32,9,9]{3,2,1,0} convert(s32[1,32,9,9]{3,2,1,0} get-tuple-element)
      ROOT relu = s8[1,32,9,9]{3,2,1,0} maximum(s8[1,32,9,9]{3,2,1,0} zeros, s8[1,32,9,9]{3,2,1,0} convert)
    })");
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  ASSERT_FALSE(CudnnFusedConvRewriter().Run(m.get()).ok());
}

TEST_F(CudnnFusedConvRewriterTest, TestFusedConvInt8ToInt8NoClamp) {
  // Even when the bias and so on are fused into the forward convolution, the
  // rewrite is still not allowed if the output is not clamped/converted to
  // int8_t:
  // max(0, alpha_conv * conv(x, w) + alpha_side * side_input + bias); for
  // int8_t

  const std::string module_str = absl::StrFormat(R"(
    HloModule Test

    ENTRY Test (input: s8[1,17,9,9], filter: s8[3,3,17,32]) -> s8[1,32,9,9] {
      zero = s8[] constant(0)
      zeros = s8[1,32,9,9]{3,2,1,0} broadcast(s8[] zero), dimensions={}
      input = s8[1,17,9,9]{3,2,1,0} parameter(0)
      filter = s8[3,3,17,32]{3,2,1,0} parameter(1)
      custom-call = (s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call(s8[1,17,9,9]{3,2,1,0} input, s8[3,3,17,32]{3,2,1,0} filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convForward", backend_config="{\"convResultScale\":1}"
      get-tuple-element = s32[1,32,9,9]{3,2,1,0} get-tuple-element((s32[1,32,9,9]{3,2,1,0}, u8[0]{0}) custom-call), index=0
      convert = s8[1,32,9,9]{3,2,1,0} convert(s32[1,32,9,9]{3,2,1,0} get-tuple-element)
      ROOT relu = s8[1,32,9,9]{3,2,1,0} maximum(s8[1,32,9,9]{3,2,1,0} zeros, s8[1,32,9,9]{3,2,1,0} convert)
    })");
  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str));

  ASSERT_FALSE(CudnnFusedConvRewriter().Run(m.get()).ok());
}

}  // namespace
}  // namespace gpu
}  // namespace xla