• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/fusion_bitcast_lift.h"
17 
18 #include <vector>
19 
20 #include "absl/types/span.h"
21 #include "tensorflow/compiler/xla/service/hlo_dce.h"
22 #include "tensorflow/compiler/xla/service/hlo_parser.h"
23 #include "tensorflow/compiler/xla/tests/filecheck.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25 
26 // TODO(b/210165681): The tests in this file are fragile to HLO op names.
27 
28 namespace xla {
29 namespace gpu {
30 namespace {
31 
32 class FusionBitcastLiftTest : public HloTestBase {};
33 
34 // Tests that we lift bitcast outside the fusion.
35 //
36 // This test MultiOutputFusion, multiple consecutive lift, bitcast
37 // with multiple users and bitcast that are used many time by the same
38 // user. This is a real kernel from Efficient Net, but with smaller
39 // shape to speed up tests.
40 //
41 // Input graph:
42 // Fusion 4d input, 2 1d output
43 //
44 // After optimization, the graph is:
45 // Bitcast 4d -> 2d
46 //   |
47 // Fusion 2d input, 2x1d outputs.
TEST_F(FusionBitcastLiftTest,NoBroadcastTest)48 TEST_F(FusionBitcastLiftTest, NoBroadcastTest) {
49   const char* hlo_text = R"(
50 HloModule mod
51 
52 %scalar_add_computation (scalar_lhs.1: f32[], scalar_rhs.1: f32[]) -> f32[] {
53   %scalar_lhs.1 = f32[] parameter(0)
54   %scalar_rhs.1 = f32[] parameter(1)
55   ROOT %add.5 = f32[] add(f32[] %scalar_lhs.1, f32[] %scalar_rhs.1)
56 }
57 
58 %fused_computation.4d (param_0: f16[2,14,14,672]) -> (f32[672], f32[672]) {
59   %param_0 = f16[2,14,14,672] parameter(0)
60   %convert = f32[2,14,14,672] convert(%param_0)
61   %bitcast.1 = f32[392,672] bitcast(%convert)
62   %constant_0 = f32[] constant(0)
63   %reduce.1 = f32[672]{0} reduce(%bitcast.1, %constant_0), dimensions={0}, to_apply=%scalar_add_computation
64   %multiply = f32[2,14,14,672] multiply(%convert, %convert)
65   %bitcast.2 = f32[392,672] bitcast(%multiply)
66   %reduce.2 = f32[672]{0} reduce(%bitcast.2, %constant_0), dimensions={0}, to_apply=%scalar_add_computation
67   ROOT %tuple = (f32[672]{0}, f32[672]{0}) tuple(%reduce.1, %reduce.2)
68 }
69 
70 ENTRY %main {
71   %param_0 = f16[2,14,14,672] parameter(0)
72   ROOT %fusion.4d = (f32[672]{0}, f32[672]{0}) fusion(%param_0), kind=kInput, calls=%fused_computation.4d
73 }
74 )";
75   auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
76   EXPECT_TRUE(FusionBitcastLift().Run(module.get()).ValueOrDie());
77   // Remove the old fusion not used anymore.
78   EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());
79 
80   EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
81 
82   StatusOr<bool> filecheck_result = RunFileCheck(module->ToString(),
83                                                  R"(
84 ; CHECK-LABEL: %fused_computation
85 ; CHECK:         f16[392,672]{1,0} parameter(0)
86 ; CHECK-NOT:     parameter
87 ; CHECK-NOT:     bitcast
88 ; CHECK-LABEL: ENTRY %main
89 ; CHECK-NEXT:    f16[2,14,14,672]{3,2,1,0} parameter(0)
90 ; CHECK-NEXT:    bitcast(
91 ; CHECK-NEXT:    fusion(
92       )");
93   EXPECT_TRUE(filecheck_result.status().ok());
94   EXPECT_TRUE(filecheck_result.ValueOrDie());
95 }
96 
97 // Tests that we lift bitcast outside the fusion when scalar broadcasting are
98 // present.
99 //
100 // Input graph:
101 // Fusion 1x4d and 1x0d inputs, 2 1d output
102 //
103 // After optimization, the graph is:
104 // Bitcast 4d -> 2d
105 //   |
106 // Fusion 1x2d and 1x0d inputs, 2 1d output
107 //   Inside the fusion, there is a bitcast left after the broadcast.
TEST_F(FusionBitcastLiftTest,ScalarBroadcastTest)108 TEST_F(FusionBitcastLiftTest, ScalarBroadcastTest) {
109   const char* hlo_text = R"(
110 HloModule mod
111 
112 %scalar_add_computation (scalar_lhs.1: f32[], scalar_rhs.1: f32[]) -> f32[] {
113   %scalar_lhs.1 = f32[] parameter(0)
114   %scalar_rhs.1 = f32[] parameter(1)
115   ROOT %add.5 = f32[] add(f32[] %scalar_lhs.1, f32[] %scalar_rhs.1)
116 }
117 
118 %fused_computation.4d (param_0: f16[2,14,14,672], param_1: f32[]) -> (f32[672], f32[672]) {
119   %param_0 = f16[2,14,14,672] parameter(0)
120   %convert = f32[2,14,14,672] convert(%param_0)
121   %bitcast.1 = f32[392,672] bitcast(%convert)
122   %constant_0 = f32[] constant(0)
123   %reduce.1 = f32[672]{0} reduce(%bitcast.1, %constant_0), dimensions={0}, to_apply=%scalar_add_computation
124   %param_1 = f32[] parameter(1)
125   %broadcast = f32[2,14,14,672] broadcast(%param_1), dimensions={}
126   %multiply = f32[2,14,14,672] multiply(%convert, %broadcast)
127   %bitcast.2 = f32[392,672] bitcast(%multiply)
128   %reduce.2 = f32[672]{0} reduce(%bitcast.2, %constant_0), dimensions={0}, to_apply=%scalar_add_computation
129   ROOT %tuple = (f32[672]{0}, f32[672]{0}) tuple(%reduce.1, %reduce.2)
130 }
131 
132 ENTRY %main {
133   %param_0 = f16[2,14,14,672] parameter(0)
134   %param_1 = f32[] parameter(1)
135   ROOT %fusion.4d = (f32[672]{0}, f32[672]{0}) fusion(%param_0, %param_1), kind=kInput, calls=%fused_computation.4d
136 }
137 )";
138   auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
139   EXPECT_TRUE(FusionBitcastLift().Run(module.get()).ValueOrDie());
140   // Remove the old fusion not used anymore.
141   EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());
142 
143   EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
144 
145   StatusOr<bool> filecheck_result = RunFileCheck(module->ToString(),
146                                                  R"(
147 ; CHECK-LABEL: %fused_computation
148 ; CHECK:         f16[392,672]{1,0} parameter(0)
149 ; CHECK:         f32[] parameter(1)
150 ; CHECK-NOT:     parameter
151 ; CHECK-NOT:     bitcast
152 ; CHECK:         %broadcast.1 =
153 ; CHECK-NEXT:    bitcast(f32[2,14,14,672]{3,2,1,0} %broadcast.1)
154 ; CHECK-NOT:     bitcast(
155 ; CHECK-LABEL: ENTRY %main
156 ; CHECK-NEXT:    f16[2,14,14,672]{3,2,1,0} parameter(0)
157 ; CHECK-NEXT:    bitcast(
158 ; CHECK-NEXT:    %param_1.1 = f32[] parameter(1)
159 ; CHECK-NEXT:    fusion(
160       )");
161   EXPECT_TRUE(filecheck_result.status().ok());
162   EXPECT_TRUE(filecheck_result.ValueOrDie());
163 }
164 
// Tests that bitcasts are lifted outside the fusion when a 1d operand is
// broadcast along the last dimension of the 4d shape.
//
// The FileCheck patterns below expect the 4d fusion parameter to become 2d,
// a single bitcast to remain inside the fusion (applied to the broadcast), and
// one bitcast to appear in the entry computation before the fusion.
TEST_F(FusionBitcastLiftTest, RowBroadcastTest) {
  const char* hlo_text = R"(
HloModule mod

%scalar_add_computation (scalar_lhs.1: f32[], scalar_rhs.1: f32[]) -> f32[] {
  %scalar_lhs.1 = f32[] parameter(0)
  %scalar_rhs.1 = f32[] parameter(1)
  ROOT %add.5 = f32[] add(f32[] %scalar_lhs.1, f32[] %scalar_rhs.1)
}

%fused_computation.4d (param_0: f16[2,14,14,672], param_1: f32[672]) -> (f32[672], f32[672]) {
  %param_0 = f16[2,14,14,672] parameter(0)
  %convert = f32[2,14,14,672] convert(%param_0)
  %bitcast.1 = f32[392,672] bitcast(%convert)
  %constant_0 = f32[] constant(0)
  %reduce.1 = f32[672]{0} reduce(%bitcast.1, %constant_0), dimensions={0}, to_apply=%scalar_add_computation
  %param_1 = f32[672] parameter(1)
  %broadcast = f32[2,14,14,672] broadcast(%param_1), dimensions={3}
  %multiply = f32[2,14,14,672] multiply(%convert, %broadcast)
  %bitcast.2 = f32[392,672] bitcast(%multiply)
  %reduce.2 = f32[672]{0} reduce(%bitcast.2, %constant_0), dimensions={0}, to_apply=%scalar_add_computation
  ROOT %tuple = (f32[672]{0}, f32[672]{0}) tuple(%reduce.1, %reduce.2)
}

ENTRY %main {
  %param_0 = f16[2,14,14,672] parameter(0)
  %param_1 = f32[672] parameter(1)
  ROOT %fusion.4d = (f32[672]{0}, f32[672]{0}) fusion(%param_0, %param_1), kind=kInput, calls=%fused_computation.4d
}
)";
  auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
  // The pass must report that it changed the module.
  EXPECT_TRUE(FusionBitcastLift().Run(module.get()).ValueOrDie());
  // Remove the old fusion not used anymore.
  EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));

  StatusOr<bool> filecheck_result = RunFileCheck(module->ToString(),
                                                 R"(
; CHECK-LABEL: %fused_computation
; CHECK:         f16[392,672]{1,0} parameter(0)
; CHECK:         f32[672]{0} parameter(1)
; CHECK-NOT:     parameter
; CHECK-NOT:     bitcast
; CHECK:         %broadcast.1
; CHECK:         bitcast(f32[2,14,14,672]{3,2,1,0} %broadcast.1)
; CHECK-NOT:     bitcast(
; CHECK-LABEL: ENTRY %main
; CHECK-NEXT:    f16[2,14,14,672]{3,2,1,0} parameter(0)
; CHECK-NEXT:    bitcast(
; CHECK-NEXT:    %param_1.1 = f32[672]{0} parameter(1)
; CHECK-NEXT:    fusion(
      )");
  // ASSERT rather than EXPECT: a non-OK status must stop the test here, because
  // the ValueOrDie() call below is not safe on a failed StatusOr.
  ASSERT_TRUE(filecheck_result.status().ok());
  EXPECT_TRUE(filecheck_result.ValueOrDie());
}
221 
// Tests that bitcasts are lifted outside the fusion when both a 1d (row)
// broadcast and a scalar broadcast feed the fused computation.
//
// The FileCheck patterns below expect exactly two bitcasts to remain inside
// the fusion, each immediately following a broadcast, and exactly one bitcast
// of %param_0 in the entry computation.
TEST_F(FusionBitcastLiftTest, ScalarAndRowBroadcastTest) {
  const char* hlo_text = R"(
HloModule mod

%scalar_add_computation (scalar_lhs.1: f32[], scalar_rhs.1: f32[]) -> f32[] {
  %scalar_lhs.1 = f32[] parameter(0)
  %scalar_rhs.1 = f32[] parameter(1)
  ROOT %add.5 = f32[] add(f32[] %scalar_lhs.1, f32[] %scalar_rhs.1)
}

%fused_computation.4d (param_0: f16[2,14,14,672], param_1: f32[672], param_2: f32[]) -> (f32[672], f32[672]) {
  %param_0 = f16[2,14,14,672] parameter(0)
  %convert = f32[2,14,14,672] convert(%param_0)
  %bitcast.1 = f32[392,672] bitcast(%convert)
  %constant_0 = f32[] constant(0)
  %reduce.1 = f32[672]{0} reduce(%bitcast.1, %constant_0), dimensions={0}, to_apply=%scalar_add_computation
  %param_1 = f32[672] parameter(1)
  %broadcast = f32[2,14,14,672] broadcast(%param_1), dimensions={3}
  %multiply.1 = f32[2,14,14,672] multiply(%convert, %broadcast)
  %param_2 = f32[] parameter(2)
  %broadcast.1 = f32[2,14,14,672] broadcast(%param_2), dimensions={}
  %multiply.2 = f32[2,14,14,672] multiply(%broadcast.1, %multiply.1)
  %bitcast.2 = f32[392,672] bitcast(%multiply.2)
  %reduce.2 = f32[672]{0} reduce(%bitcast.2, %constant_0), dimensions={0}, to_apply=%scalar_add_computation
  ROOT %tuple = (f32[672]{0}, f32[672]{0}) tuple(%reduce.1, %reduce.2)
}

ENTRY %main {
  %param_0 = f16[2,14,14,672] parameter(0)
  %param_1 = f32[672] parameter(1)
  %param_2 = f32[] parameter(2)
  ROOT %fusion.4d = (f32[672]{0}, f32[672]{0}) fusion(%param_0, %param_1, %param_2), kind=kInput, calls=%fused_computation.4d
}
)";

  auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
  // The pass must report that it changed the module.
  EXPECT_TRUE(FusionBitcastLift().Run(module.get()).ValueOrDie());
  // Remove the old fusion not used anymore.
  EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));

  StatusOr<bool> filecheck_result = RunFileCheck(module->ToString(),
                                                 R"(
; CHECK-LABEL: %fused_computation
; CHECK-NOT:     bitcast
; CHECK:         f32[2,14,14,672]{3,2,1,0} broadcast(
; CHECK-NEXT:    f32[392,672]{1,0} bitcast(
; CHECK-NOT:     bitcast
; CHECK:         f32[2,14,14,672]{3,2,1,0} broadcast(
; CHECK-NEXT:    f32[392,672]{1,0} bitcast(
; CHECK-NOT:     bitcast(
; CHECK-LABEL: ENTRY %main
; CHECK-NOT:     bitcast(
; CHECK:    bitcast(f16[2,14,14,672]{3,2,1,0} %param_0
; CHECK-NOT:     bitcast(
      )");
  // ASSERT rather than EXPECT: a non-OK status must stop the test here, because
  // the ValueOrDie() call below is not safe on a failed StatusOr.
  ASSERT_TRUE(filecheck_result.status().ok());
  EXPECT_TRUE(filecheck_result.ValueOrDie());
}
282 
283 // To trigger the bitcast same pattern check.
TEST_F(FusionBitcastLiftTest,StrangeBitcastBroadcastTest)284 TEST_F(FusionBitcastLiftTest, StrangeBitcastBroadcastTest) {
285   const char* hlo_text = R"(
286 HloModule mod
287 
288 %scalar_add_computation (scalar_lhs.1: f32[], scalar_rhs.1: f32[]) -> f32[] {
289   %scalar_lhs.1 = f32[] parameter(0)
290   %scalar_rhs.1 = f32[] parameter(1)
291   ROOT %add.5 = f32[] add(f32[] %scalar_lhs.1, f32[] %scalar_rhs.1)
292 }
293 
294 %fused_computation.4d (param_0: f16[2,14,14,672], param_1: f32[672], param_2: f32[672]) -> (f32[672], f32[672]) {
295   %param_0 = f16[2,14,14,672] parameter(0)
296   %convert = f32[2,14,14,672] convert(%param_0)
297   %bitcast.1 = f32[392,672] bitcast(%convert)
298   %constant_0 = f32[] constant(0)
299   %reduce.1 = f32[672]{0} reduce(%bitcast.1, %constant_0), dimensions={0}, to_apply=%scalar_add_computation
300   %param_1 = f32[672] parameter(1)
301   %broadcast = f32[2,14,14,672] broadcast(%param_1), dimensions={3}
302   %multiply.1 = f32[2,14,14,672] multiply(%convert, %broadcast)
303   %param_2 = f32[672] parameter(2)
304   %broadcast.1 = f32[28,14,672] broadcast(%param_2), dimensions={2}
305   %bitcast.4 = f32[28,14,672] bitcast(%multiply.1)
306   %multiply.2 = f32[28,14,672] multiply(%broadcast.1, %bitcast.4)
307   %bitcast.2 = f32[392,672] bitcast(%multiply.2)
308   %reduce.2 = f32[672]{0} reduce(%bitcast.2, %constant_0), dimensions={0}, to_apply=%scalar_add_computation
309   ROOT %tuple = (f32[672]{0}, f32[672]{0}) tuple(%reduce.1, %reduce.2)
310 }
311 
312 ENTRY %main {
313   %param_0 = f16[2,14,14,672] parameter(0)
314   %param_1 = f32[672] parameter(1)
315   %param_2 = f32[672] parameter(2)
316   ROOT %fusion.4d = (f32[672]{0}, f32[672]{0}) fusion(%param_0, %param_1, %param_2), kind=kInput, calls=%fused_computation.4d
317 }
318 )";
319 
320   auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
321   EXPECT_FALSE(FusionBitcastLift().Run(module.get()).ValueOrDie());
322 }
323 
// Tests lifting when the only bitcast in the fusion (f32[1] -> f32[]) is fed
// by an add of a constant and a parameter and is used as the init value of
// both reduces.
//
// The FileCheck patterns below expect exactly one bitcast to remain inside the
// fusion and one bitcast to be created in the entry computation right before
// the fusion.
TEST_F(FusionBitcastLiftTest, ConstantBitcastTest) {
  const char* hlo_text = R"(
HloModule mod

%scalar_add_computation (scalar_lhs.1: f32[], scalar_rhs.1: f32[]) -> f32[] {
  %scalar_lhs.1 = f32[] parameter(0)
  %scalar_rhs.1 = f32[] parameter(1)
  ROOT %add.5 = f32[] add(f32[] %scalar_lhs.1, f32[] %scalar_rhs.1)
}

%fused_computation (param_0: f16[392,672], param_1: f32[1]) -> (f32[672], f32[672]) {
  %param_0 = f16[392,672] parameter(0)
  %convert = f32[392,672] convert(%param_0)

  %param_1 = f32[1] parameter(1)
  %constant_0 = f32[1] constant({1.2})
  %add = f32[1] add(%constant_0, %param_1)
  %bitcast.2 = f32[] bitcast(%add)

  %reduce.1 = f32[672]{0} reduce(%convert, %bitcast.2), dimensions={0}, to_apply=%scalar_add_computation
  %multiply = f32[392,672] multiply(%convert, %convert)
  %reduce.2 = f32[672]{0} reduce(%multiply, %bitcast.2), dimensions={0}, to_apply=%scalar_add_computation
  ROOT %tuple = (f32[672]{0}, f32[672]{0}) tuple(%reduce.1, %reduce.2)
}

ENTRY %main {
  %param_0 = f16[392,672] parameter(0)
  %param_1 = f32[1] parameter(1)
  ROOT %fusion = (f32[672]{0}, f32[672]{0}) fusion(%param_0, %param_1), kind=kInput, calls=%fused_computation
}
)";
  auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
  // The pass must report that it changed the module.
  EXPECT_TRUE(FusionBitcastLift().Run(module.get()).ValueOrDie());
  // Remove the old fusion not used anymore.
  EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));

  StatusOr<bool> filecheck_result = RunFileCheck(module->ToString(),
                                                 R"(
; CHECK-LABEL: %fused_computation
; CHECK:         f16[392,672]{1,0} parameter(0)
; CHECK-COUNT-1: bitcast(
; CHECK-NOT:     bitcast(
; CHECK-LABEL: ENTRY %main
; CHECK-NEXT:    f16[392,672]{1,0} parameter(0)
; CHECK-NEXT:    f32[1]{0} parameter(1)
; CHECK-NEXT:    bitcast(
; CHECK-NEXT:    fusion(
      )");
  // ASSERT rather than EXPECT: a non-OK status must stop the test here, because
  // the ValueOrDie() call below is not safe on a failed StatusOr.
  ASSERT_TRUE(filecheck_result.status().ok());
  EXPECT_TRUE(filecheck_result.ValueOrDie());
}
377 
// Tests lifting on a larger multi-output fusion with seven parameters (three
// 4d f16 inputs and four 1d f32 inputs), several row and scalar broadcasts,
// and two 4d -> 2d bitcasts feeding the reduces.
//
// The FileCheck patterns below expect exactly six bitcasts to remain inside
// the fusion and exactly three bitcasts in the entry computation before the
// fusion.
TEST_F(FusionBitcastLiftTest, Swish1Test) {
  const char* hlo_text = R"(
HloModule mod

%scalar_add_computation (scalar_lhs.1: f32[], scalar_rhs.1: f32[]) -> f32[] {
  %scalar_lhs.1 = f32[] parameter(0)
  %scalar_rhs.1 = f32[] parameter(1)
  ROOT %add.5 = f32[] add(f32[] %scalar_lhs.1, f32[] %scalar_rhs.1)
}

%fused_computation (param_0.90: f32[672], param_1.127: f16[2,14,14,672], param_2.77: f16[2,14,14,672], param_3.57: f16[2,14,14,672], param_4.57: f32[672], param_5.63: f32[672], param_6.44: f32[672]) -> (f32[672], f32[672]) {
  %param_2.77 = f16[2,14,14,672]{3,2,1,0} parameter(2)
  %param_3.57 = f16[2,14,14,672]{3,2,1,0} parameter(3)
  %constant_153 = f16[] constant(1)
  %broadcast.174 = f16[2,14,14,672]{3,2,1,0} broadcast(f16[] %constant_153), dimensions={}
  %param_1.127 = f16[2,14,14,672]{3,2,1,0} parameter(1)
  %convert.46 = f32[2,14,14,672]{3,2,1,0} convert(f16[2,14,14,672]{3,2,1,0} %param_1.127)
  %param_0.90 = f32[672]{0} parameter(0)
  %constant_77_clone_1 = f32[] constant(9.96492327e-06)
  %broadcast.173 = f32[672]{0} broadcast(f32[] %constant_77_clone_1), dimensions={}
  %multiply.155 = f32[672]{0} multiply(f32[672]{0} %param_0.90, f32[672]{0} %broadcast.173)
  %broadcast.172 = f32[2,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %multiply.155), dimensions={3}
  %subtract.55 = f32[2,14,14,672]{3,2,1,0} subtract(f32[2,14,14,672]{3,2,1,0} %convert.46, f32[2,14,14,672]{3,2,1,0} %broadcast.172)
  %param_6.44 = f32[672]{0} parameter(6)
  %multiply.154 = f32[672]{0} multiply(f32[672]{0} %param_6.44, f32[672]{0} %broadcast.173)
  %multiply.153 = f32[672]{0} multiply(f32[672]{0} %multiply.155, f32[672]{0} %multiply.155)
  %subtract.54 = f32[672]{0} subtract(f32[672]{0} %multiply.154, f32[672]{0} %multiply.153)
  %constant_151 = f32[] constant(0.001)
  %broadcast.171 = f32[672]{0} broadcast(f32[] %constant_151), dimensions={}
  %add.50 = f32[672]{0} add(f32[672]{0} %subtract.54, f32[672]{0} %broadcast.171)
  %rsqrt.23 = f32[672]{0} rsqrt(f32[672]{0} %add.50)
  %broadcast.170 = f32[2,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %rsqrt.23), dimensions={3}
  %multiply.152 = f32[2,14,14,672]{3,2,1,0} multiply(f32[2,14,14,672]{3,2,1,0} %subtract.55, f32[2,14,14,672]{3,2,1,0} %broadcast.170)
  %param_5.63 = f32[672]{0} parameter(5)
  %broadcast.169 = f32[2,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %param_5.63), dimensions={3}
  %multiply.151 = f32[2,14,14,672]{3,2,1,0} multiply(f32[2,14,14,672]{3,2,1,0} %multiply.152, f32[2,14,14,672]{3,2,1,0} %broadcast.169)
  %param_4.57 = f32[672]{0} parameter(4)
  %broadcast.168 = f32[2,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %param_4.57), dimensions={3}
  %add.48 = f32[2,14,14,672]{3,2,1,0} add(f32[2,14,14,672]{3,2,1,0} %multiply.151, f32[2,14,14,672]{3,2,1,0} %broadcast.168)
  %convert.45 = f16[2,14,14,672]{3,2,1,0} convert(f32[2,14,14,672]{3,2,1,0} %add.48)
  %subtract.53 = f16[2,14,14,672]{3,2,1,0} subtract(f16[2,14,14,672]{3,2,1,0} %broadcast.174, f16[2,14,14,672]{3,2,1,0} %param_3.57)
  %multiply.150 = f16[2,14,14,672]{3,2,1,0} multiply(f16[2,14,14,672]{3,2,1,0} %convert.45, f16[2,14,14,672]{3,2,1,0} %subtract.53)
  %add.47 = f16[2,14,14,672]{3,2,1,0} add(f16[2,14,14,672]{3,2,1,0} %broadcast.174, f16[2,14,14,672]{3,2,1,0} %multiply.150)
  %multiply.149 = f16[2,14,14,672]{3,2,1,0} multiply(f16[2,14,14,672]{3,2,1,0} %param_3.57, f16[2,14,14,672]{3,2,1,0} %add.47)
  %multiply.148 = f16[2,14,14,672]{3,2,1,0} multiply(f16[2,14,14,672]{3,2,1,0} %param_2.77, f16[2,14,14,672]{3,2,1,0} %multiply.149)
  %convert.10 = f32[2,14,14,672]{3,2,1,0} convert(f16[2,14,14,672]{3,2,1,0} %multiply.148)
  %bitcast.21 = f32[392,672]{1,0} bitcast(f32[2,14,14,672]{3,2,1,0} %convert.10)
  %constant_57 = f32[] constant(0)
  %reduce.9 = f32[672]{0} reduce(f32[392,672]{1,0} %bitcast.21, f32[] %constant_57), dimensions={0}, to_apply=%scalar_add_computation
  %multiply.30.clone.1 = f32[2,14,14,672]{3,2,1,0} multiply(f32[2,14,14,672]{3,2,1,0} %convert.10, f32[2,14,14,672]{3,2,1,0} %subtract.55)
  %bitcast.20.clone.1 = f32[392,672]{1,0} bitcast(f32[2,14,14,672]{3,2,1,0} %multiply.30.clone.1)
  %reduce.8.clone.1 = f32[672]{0} reduce(f32[392,672]{1,0} %bitcast.20.clone.1, f32[] %constant_57), dimensions={0}, to_apply=%scalar_add_computation
  ROOT %tuple.9 = (f32[672]{0}, f32[672]{0}) tuple(f32[672]{0} %reduce.9, f32[672]{0} %reduce.8.clone.1)
}

ENTRY %main {
  %param_0 = f32[672]{0} parameter(0)
  %param_1 = f16[2,14,14,672]{3,2,1,0} parameter(1)
  %param_2 = f16[2,14,14,672]{3,2,1,0} parameter(2)
  %param_3 = f16[2,14,14,672]{3,2,1,0} parameter(3)
  %param_4 = f32[672]{0} parameter(4)
  %param_5 = f32[672]{0} parameter(5)
  %param_6 = f32[672]{0} parameter(6)

  ROOT %fusion = (f32[672]{0}, f32[672]{0}) fusion(%param_0, %param_1, %param_2, %param_3, %param_4, %param_5, %param_6), kind=kInput, calls=%fused_computation
}
)";
  auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
  // The pass must report that it changed the module.
  EXPECT_TRUE(FusionBitcastLift().Run(module.get()).ValueOrDie());
  // Remove the old fusion not used anymore.
  EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));

  StatusOr<bool> filecheck_result = RunFileCheck(module->ToString(),
                                                 R"(
; CHECK-LABEL: %fused_computation
; CHECK-COUNT-6: bitcast(
; CHECK-NOT:     bitcast(
; CHECK-LABEL: ENTRY %main
; CHECK-COUNT-3: bitcast(
; CHECK-NOT:     bitcast(
; CHECK:         fusion(
      )");
  // ASSERT rather than EXPECT: a non-OK status must stop the test here, because
  // the ValueOrDie() call below is not safe on a failed StatusOr.
  ASSERT_TRUE(filecheck_result.status().ok());
  EXPECT_TRUE(filecheck_result.ValueOrDie());
}
465 
// Tests lifting on a larger multi-output fusion with six parameters (two 4d
// f16 inputs and four 1d f32 inputs), several row and scalar broadcasts, and
// two 4d -> 2d bitcasts feeding the reduces.
//
// The FileCheck patterns below expect exactly eight bitcasts to remain inside
// the fusion and exactly two bitcasts in the entry computation before the
// fusion.
TEST_F(FusionBitcastLiftTest, Swish2Test) {
  const char* hlo_text = R"(
HloModule mod

%scalar_add_computation (scalar_lhs.1: f32[], scalar_rhs.1: f32[]) -> f32[] {
  %scalar_lhs.1 = f32[] parameter(0)
  %scalar_rhs.1 = f32[] parameter(1)
  ROOT %add.5 = f32[] add(f32[] %scalar_lhs.1, f32[] %scalar_rhs.1)
}


%fused_computation (param_0.95: f32[672], param_1.128: f16[2,14,14,672], param_2.81: f16[2,14,14,672], param_3.66: f32[672], param_4.61: f32[672], param_5.62: f32[672]) -> (f32[672], f32[672]) {
  %param_2.81 = f16[2,14,14,672]{3,2,1,0} parameter(2)
  %constant_211 = f16[] constant(1)
  %broadcast.288 = f16[2,14,14,672]{3,2,1,0} broadcast(f16[] %constant_211), dimensions={}
  %param_1.128 = f16[2,14,14,672]{3,2,1,0} parameter(1)
  %convert.74 = f32[2,14,14,672]{3,2,1,0} convert(f16[2,14,14,672]{3,2,1,0} %param_1.128)
  %param_0.95 = f32[672]{0} parameter(0)
  %constant_77 = f32[] constant(9.96492327e-06)
  %broadcast.287 = f32[672]{0} broadcast(f32[] %constant_77), dimensions={}
  %multiply.253 = f32[672]{0} multiply(f32[672]{0} %param_0.95, f32[672]{0} %broadcast.287)
  %broadcast.286 = f32[2,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %multiply.253), dimensions={3}
  %subtract.92 = f32[2,14,14,672]{3,2,1,0} subtract(f32[2,14,14,672]{3,2,1,0} %convert.74, f32[2,14,14,672]{3,2,1,0} %broadcast.286)
  %param_5.62 = f32[672]{0} parameter(5)
  %multiply.252 = f32[672]{0} multiply(f32[672]{0} %param_5.62, f32[672]{0} %broadcast.287)
  %multiply.250 = f32[672]{0} multiply(f32[672]{0} %multiply.253, f32[672]{0} %multiply.253)
  %subtract.91 = f32[672]{0} subtract(f32[672]{0} %multiply.252, f32[672]{0} %multiply.250)
  %constant_208 = f32[] constant(0.001)
  %broadcast.284 = f32[672]{0} broadcast(f32[] %constant_208), dimensions={}
  %add.93 = f32[672]{0} add(f32[672]{0} %subtract.91, f32[672]{0} %broadcast.284)
  %rsqrt.37 = f32[672]{0} rsqrt(f32[672]{0} %add.93)
  %broadcast.283 = f32[2,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %rsqrt.37), dimensions={3}
  %multiply.249 = f32[2,14,14,672]{3,2,1,0} multiply(f32[2,14,14,672]{3,2,1,0} %subtract.92, f32[2,14,14,672]{3,2,1,0} %broadcast.283)
  %param_4.61 = f32[672]{0} parameter(4)
  %broadcast.282 = f32[2,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %param_4.61), dimensions={3}
  %multiply.248 = f32[2,14,14,672]{3,2,1,0} multiply(f32[2,14,14,672]{3,2,1,0} %multiply.249, f32[2,14,14,672]{3,2,1,0} %broadcast.282)
  %param_3.66 = f32[672]{0} parameter(3)
  %broadcast.281 = f32[2,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %param_3.66), dimensions={3}
  %add.92 = f32[2,14,14,672]{3,2,1,0} add(f32[2,14,14,672]{3,2,1,0} %multiply.248, f32[2,14,14,672]{3,2,1,0} %broadcast.281)
  %convert.73 = f16[2,14,14,672]{3,2,1,0} convert(f32[2,14,14,672]{3,2,1,0} %add.92)
  %negate.14 = f16[2,14,14,672]{3,2,1,0} negate(f16[2,14,14,672]{3,2,1,0} %convert.73)
  %exponential.12 = f16[2,14,14,672]{3,2,1,0} exponential(f16[2,14,14,672]{3,2,1,0} %negate.14)
  %add.91 = f16[2,14,14,672]{3,2,1,0} add(f16[2,14,14,672]{3,2,1,0} %broadcast.288, f16[2,14,14,672]{3,2,1,0} %exponential.12)
  %divide.22 = f16[2,14,14,672]{3,2,1,0} divide(f16[2,14,14,672]{3,2,1,0} %broadcast.288, f16[2,14,14,672]{3,2,1,0} %add.91)
  %subtract.88 = f16[2,14,14,672]{3,2,1,0} subtract(f16[2,14,14,672]{3,2,1,0} %broadcast.288, f16[2,14,14,672]{3,2,1,0} %divide.22)
  %multiply.241 = f16[2,14,14,672]{3,2,1,0} multiply(f16[2,14,14,672]{3,2,1,0} %convert.73, f16[2,14,14,672]{3,2,1,0} %subtract.88)
  %add.87 = f16[2,14,14,672]{3,2,1,0} add(f16[2,14,14,672]{3,2,1,0} %broadcast.288, f16[2,14,14,672]{3,2,1,0} %multiply.241)
  %multiply.240 = f16[2,14,14,672]{3,2,1,0} multiply(f16[2,14,14,672]{3,2,1,0} %divide.22, f16[2,14,14,672]{3,2,1,0} %add.87)
  %multiply.239 = f16[2,14,14,672]{3,2,1,0} multiply(f16[2,14,14,672]{3,2,1,0} %param_2.81, f16[2,14,14,672]{3,2,1,0} %multiply.240)
  %convert.9 = f32[2,14,14,672]{3,2,1,0} convert(f16[2,14,14,672]{3,2,1,0} %multiply.239)
  %multiply.30 = f32[2,14,14,672]{3,2,1,0} multiply(f32[2,14,14,672]{3,2,1,0} %convert.9, f32[2,14,14,672]{3,2,1,0} %subtract.92)
  %bitcast.20 = f32[392,672]{1,0} bitcast(f32[2,14,14,672]{3,2,1,0} %multiply.30)
  %constant_58 = f32[] constant(0)
  %reduce.8 = f32[672]{0} reduce(f32[392,672]{1,0} %bitcast.20, f32[] %constant_58), dimensions={0}, to_apply=%scalar_add_computation
  %bitcast.21.clone.1 = f32[392,672]{1,0} bitcast(f32[2,14,14,672]{3,2,1,0} %convert.9)
  %reduce.9.clone.1 = f32[672]{0} reduce(f32[392,672]{1,0} %bitcast.21.clone.1, f32[] %constant_58), dimensions={0}, to_apply=%scalar_add_computation
  ROOT %tuple.9 = (f32[672]{0}, f32[672]{0}) tuple(f32[672]{0} %reduce.8, f32[672]{0} %reduce.9.clone.1)
}

ENTRY %main {
  %param_0 = f32[672]{0} parameter(0)
  %param_1 = f16[2,14,14,672]{3,2,1,0} parameter(1)
  %param_2 = f16[2,14,14,672]{3,2,1,0} parameter(2)
  %param_3 = f32[672]{0} parameter(3)
  %param_4 = f32[672]{0} parameter(4)
  %param_5 = f32[672]{0} parameter(5)

  ROOT %fusion = (f32[672]{0}, f32[672]{0}) fusion(%param_0, %param_1, %param_2, %param_3, %param_4, %param_5), kind=kInput, calls=%fused_computation
}
)";
  auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
  // The pass must report that it changed the module.
  EXPECT_TRUE(FusionBitcastLift().Run(module.get()).ValueOrDie());
  // Remove the old fusion not used anymore.
  EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));

  StatusOr<bool> filecheck_result = RunFileCheck(module->ToString(),
                                                 R"(
; CHECK-LABEL: %fused_computation
; CHECK-COUNT-8: bitcast(
; CHECK-NOT:     bitcast(
; CHECK-LABEL: ENTRY %main
; CHECK-COUNT-2: bitcast(
; CHECK-NOT:     bitcast(
; CHECK:         fusion(
      )");
  // ASSERT rather than EXPECT: a non-OK status must stop the test here, because
  // the ValueOrDie() call below is not safe on a failed StatusOr.
  ASSERT_TRUE(filecheck_result.status().ok());
  EXPECT_TRUE(filecheck_result.ValueOrDie());
}
556 
// Tests that the pass leaves the module unchanged when the fused bitcast goes
// from an operand with a non-default layout (f32[1,1,1536,3072]{1,0,2,3},
// produced by a copy inside the fusion) to f32[3072,1536]{1,0}. The pass is
// expected to report `false`.
TEST_F(FusionBitcastLiftTest, LayoutChangeNotSupported) {
  const char* hlo_text = R"(
HloModule bla

add {
  param0 = f32[] parameter(0)
  param1 = f32[] parameter(1)
  ROOT add = f32[] add(param0, param1)
}

fused_computation {
  param_1.11485 = f32[1,1,1536,3072]{3,2,1,0} parameter(1)
  copy.1383 = f32[1,1,1536,3072]{1,0,2,3} copy(param_1.11485)
  param_0.7122 = f32[3072]{0} parameter(0)
  constant.9031 = f32[] constant(0.000651041686)
  broadcast.9040 = f32[3072]{0} broadcast(constant.9031), dimensions={}
  multiply.7225 = f32[3072]{0} multiply(param_0.7122, broadcast.9040)
  broadcast.9039 = f32[1,1,1536,3072]{1,0,2,3} broadcast(multiply.7225), dimensions={3}
  subtract.940 = f32[1,1,1536,3072]{1,0,2,3} subtract(copy.1383, broadcast.9039)
  multiply.7224 = f32[1,1,1536,3072]{1,0,2,3} multiply(subtract.940, subtract.940)
  bitcast.3805 = f32[3072,1536]{1,0} bitcast(multiply.7224)
  constant.25971 = f32[] constant(0)
  ROOT reduce.790 = f32[3072]{0} reduce(bitcast.3805, constant.25971), dimensions={1}, to_apply=add
}

ENTRY entry {
  param_0.0 = f32[3072]{0} parameter(0)
  param_1.0 = f32[1,1,1536,3072]{3,2,1,0} parameter(1)
  ROOT fusion = f32[3072]{0} fusion(param_0.0, param_1.0), kind=kInput, calls=fused_computation
}
)";
  auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
  // The pass must report that nothing was changed.
  EXPECT_FALSE(FusionBitcastLift().Run(module.get()).ValueOrDie());
}
591 
592 }  // namespace
593 }  // namespace gpu
594 }  // namespace xla
595