1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/fusion_bitcast_lift.h"
17
18 #include <vector>
19
20 #include "absl/types/span.h"
21 #include "tensorflow/compiler/xla/service/hlo_dce.h"
22 #include "tensorflow/compiler/xla/service/hlo_parser.h"
23 #include "tensorflow/compiler/xla/tests/filecheck.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25
26 // TODO(b/210165681): The tests in this file are fragile to HLO op names.
27
28 namespace xla {
29 namespace gpu {
30 namespace {
31
32 class FusionBitcastLiftTest : public HloTestBase {};
33
34 // Tests that we lift bitcast outside the fusion.
35 //
36 // This test MultiOutputFusion, multiple consecutive lift, bitcast
37 // with multiple users and bitcast that are used many time by the same
38 // user. This is a real kernel from Efficient Net, but with smaller
39 // shape to speed up tests.
40 //
41 // Input graph:
42 // Fusion 4d input, 2 1d output
43 //
44 // After optimization, the graph is:
45 // Bitcast 4d -> 2d
46 // |
47 // Fusion 2d input, 2x1d outputs.
TEST_F(FusionBitcastLiftTest,NoBroadcastTest)48 TEST_F(FusionBitcastLiftTest, NoBroadcastTest) {
49 const char* hlo_text = R"(
50 HloModule mod
51
52 %scalar_add_computation (scalar_lhs.1: f32[], scalar_rhs.1: f32[]) -> f32[] {
53 %scalar_lhs.1 = f32[] parameter(0)
54 %scalar_rhs.1 = f32[] parameter(1)
55 ROOT %add.5 = f32[] add(f32[] %scalar_lhs.1, f32[] %scalar_rhs.1)
56 }
57
58 %fused_computation.4d (param_0: f16[2,14,14,672]) -> (f32[672], f32[672]) {
59 %param_0 = f16[2,14,14,672] parameter(0)
60 %convert = f32[2,14,14,672] convert(%param_0)
61 %bitcast.1 = f32[392,672] bitcast(%convert)
62 %constant_0 = f32[] constant(0)
63 %reduce.1 = f32[672]{0} reduce(%bitcast.1, %constant_0), dimensions={0}, to_apply=%scalar_add_computation
64 %multiply = f32[2,14,14,672] multiply(%convert, %convert)
65 %bitcast.2 = f32[392,672] bitcast(%multiply)
66 %reduce.2 = f32[672]{0} reduce(%bitcast.2, %constant_0), dimensions={0}, to_apply=%scalar_add_computation
67 ROOT %tuple = (f32[672]{0}, f32[672]{0}) tuple(%reduce.1, %reduce.2)
68 }
69
70 ENTRY %main {
71 %param_0 = f16[2,14,14,672] parameter(0)
72 ROOT %fusion.4d = (f32[672]{0}, f32[672]{0}) fusion(%param_0), kind=kInput, calls=%fused_computation.4d
73 }
74 )";
75 auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
76 EXPECT_TRUE(FusionBitcastLift().Run(module.get()).ValueOrDie());
77 // Remove the old fusion not used anymore.
78 EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());
79
80 EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
81
82 StatusOr<bool> filecheck_result = RunFileCheck(module->ToString(),
83 R"(
84 ; CHECK-LABEL: %fused_computation
85 ; CHECK: f16[392,672]{1,0} parameter(0)
86 ; CHECK-NOT: parameter
87 ; CHECK-NOT: bitcast
88 ; CHECK-LABEL: ENTRY %main
89 ; CHECK-NEXT: f16[2,14,14,672]{3,2,1,0} parameter(0)
90 ; CHECK-NEXT: bitcast(
91 ; CHECK-NEXT: fusion(
92 )");
93 EXPECT_TRUE(filecheck_result.status().ok());
94 EXPECT_TRUE(filecheck_result.ValueOrDie());
95 }
96
97 // Tests that we lift bitcast outside the fusion when scalar broadcasting are
98 // present.
99 //
100 // Input graph:
101 // Fusion 1x4d and 1x0d inputs, 2 1d output
102 //
103 // After optimization, the graph is:
104 // Bitcast 4d -> 2d
105 // |
106 // Fusion 1x2d and 1x0d inputs, 2 1d output
107 // Inside the fusion, there is a bitcast left after the broadcast.
TEST_F(FusionBitcastLiftTest,ScalarBroadcastTest)108 TEST_F(FusionBitcastLiftTest, ScalarBroadcastTest) {
109 const char* hlo_text = R"(
110 HloModule mod
111
112 %scalar_add_computation (scalar_lhs.1: f32[], scalar_rhs.1: f32[]) -> f32[] {
113 %scalar_lhs.1 = f32[] parameter(0)
114 %scalar_rhs.1 = f32[] parameter(1)
115 ROOT %add.5 = f32[] add(f32[] %scalar_lhs.1, f32[] %scalar_rhs.1)
116 }
117
118 %fused_computation.4d (param_0: f16[2,14,14,672], param_1: f32[]) -> (f32[672], f32[672]) {
119 %param_0 = f16[2,14,14,672] parameter(0)
120 %convert = f32[2,14,14,672] convert(%param_0)
121 %bitcast.1 = f32[392,672] bitcast(%convert)
122 %constant_0 = f32[] constant(0)
123 %reduce.1 = f32[672]{0} reduce(%bitcast.1, %constant_0), dimensions={0}, to_apply=%scalar_add_computation
124 %param_1 = f32[] parameter(1)
125 %broadcast = f32[2,14,14,672] broadcast(%param_1), dimensions={}
126 %multiply = f32[2,14,14,672] multiply(%convert, %broadcast)
127 %bitcast.2 = f32[392,672] bitcast(%multiply)
128 %reduce.2 = f32[672]{0} reduce(%bitcast.2, %constant_0), dimensions={0}, to_apply=%scalar_add_computation
129 ROOT %tuple = (f32[672]{0}, f32[672]{0}) tuple(%reduce.1, %reduce.2)
130 }
131
132 ENTRY %main {
133 %param_0 = f16[2,14,14,672] parameter(0)
134 %param_1 = f32[] parameter(1)
135 ROOT %fusion.4d = (f32[672]{0}, f32[672]{0}) fusion(%param_0, %param_1), kind=kInput, calls=%fused_computation.4d
136 }
137 )";
138 auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
139 EXPECT_TRUE(FusionBitcastLift().Run(module.get()).ValueOrDie());
140 // Remove the old fusion not used anymore.
141 EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());
142
143 EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
144
145 StatusOr<bool> filecheck_result = RunFileCheck(module->ToString(),
146 R"(
147 ; CHECK-LABEL: %fused_computation
148 ; CHECK: f16[392,672]{1,0} parameter(0)
149 ; CHECK: f32[] parameter(1)
150 ; CHECK-NOT: parameter
151 ; CHECK-NOT: bitcast
152 ; CHECK: %broadcast.1 =
153 ; CHECK-NEXT: bitcast(f32[2,14,14,672]{3,2,1,0} %broadcast.1)
154 ; CHECK-NOT: bitcast(
155 ; CHECK-LABEL: ENTRY %main
156 ; CHECK-NEXT: f16[2,14,14,672]{3,2,1,0} parameter(0)
157 ; CHECK-NEXT: bitcast(
158 ; CHECK-NEXT: %param_1.1 = f32[] parameter(1)
159 ; CHECK-NEXT: fusion(
160 )");
161 EXPECT_TRUE(filecheck_result.status().ok());
162 EXPECT_TRUE(filecheck_result.ValueOrDie());
163 }
164
// Tests that we lift bitcasts outside the fusion when a row broadcast is
// present, i.e. a f32[672] parameter broadcast along dimension 3 of the 4d
// shape.
//
// The FileCheck patterns expect the f16 input to reach the fusion already
// reshaped to [392,672], and a single bitcast to remain inside the fusion,
// applied to the broadcast.
TEST_F(FusionBitcastLiftTest, RowBroadcastTest) {
  const char* hlo_text = R"(
HloModule mod

%scalar_add_computation (scalar_lhs.1: f32[], scalar_rhs.1: f32[]) -> f32[] {
%scalar_lhs.1 = f32[] parameter(0)
%scalar_rhs.1 = f32[] parameter(1)
ROOT %add.5 = f32[] add(f32[] %scalar_lhs.1, f32[] %scalar_rhs.1)
}

%fused_computation.4d (param_0: f16[2,14,14,672], param_1: f32[672]) -> (f32[672], f32[672]) {
%param_0 = f16[2,14,14,672] parameter(0)
%convert = f32[2,14,14,672] convert(%param_0)
%bitcast.1 = f32[392,672] bitcast(%convert)
%constant_0 = f32[] constant(0)
%reduce.1 = f32[672]{0} reduce(%bitcast.1, %constant_0), dimensions={0}, to_apply=%scalar_add_computation
%param_1 = f32[672] parameter(1)
%broadcast = f32[2,14,14,672] broadcast(%param_1), dimensions={3}
%multiply = f32[2,14,14,672] multiply(%convert, %broadcast)
%bitcast.2 = f32[392,672] bitcast(%multiply)
%reduce.2 = f32[672]{0} reduce(%bitcast.2, %constant_0), dimensions={0}, to_apply=%scalar_add_computation
ROOT %tuple = (f32[672]{0}, f32[672]{0}) tuple(%reduce.1, %reduce.2)
}

ENTRY %main {
%param_0 = f16[2,14,14,672] parameter(0)
%param_1 = f32[672] parameter(1)
ROOT %fusion.4d = (f32[672]{0}, f32[672]{0}) fusion(%param_0, %param_1), kind=kInput, calls=%fused_computation.4d
}
)";
  auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
  // The pass must report that it changed the module.
  EXPECT_TRUE(FusionBitcastLift().Run(module.get()).ValueOrDie());
  // Remove the old fusion, which is not used anymore.
  EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());

  // Numerical check on the original HLO text.
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));

  StatusOr<bool> filecheck_result = RunFileCheck(module->ToString(),
                                                 R"(
; CHECK-LABEL: %fused_computation
; CHECK: f16[392,672]{1,0} parameter(0)
; CHECK: f32[672]{0} parameter(1)
; CHECK-NOT: parameter
; CHECK-NOT: bitcast
; CHECK: %broadcast.1
; CHECK: bitcast(f32[2,14,14,672]{3,2,1,0} %broadcast.1)
; CHECK-NOT: bitcast(
; CHECK-LABEL: ENTRY %main
; CHECK-NEXT: f16[2,14,14,672]{3,2,1,0} parameter(0)
; CHECK-NEXT: bitcast(
; CHECK-NEXT: %param_1.1 = f32[672]{0} parameter(1)
; CHECK-NEXT: fusion(
)");
  // ASSERT (not EXPECT) so that ValueOrDie() below is never reached on a
  // failed status, which would abort the whole test binary.
  ASSERT_TRUE(filecheck_result.status().ok())
      << filecheck_result.status().ToString();
  EXPECT_TRUE(filecheck_result.ValueOrDie());
}
221
// Tests that we lift bitcasts outside the fusion when both a row broadcast
// (f32[672] along dimension 3) and a scalar broadcast feed the same
// elementwise chain.
//
// The FileCheck patterns expect each of the two broadcasts inside the fusion
// to be immediately followed by a bitcast to f32[392,672], with no other
// bitcast in the fusion, and exactly one bitcast in the entry computation,
// applied to %param_0.
TEST_F(FusionBitcastLiftTest, ScalarAndRowBroadcastTest) {
  const char* hlo_text = R"(
HloModule mod

%scalar_add_computation (scalar_lhs.1: f32[], scalar_rhs.1: f32[]) -> f32[] {
%scalar_lhs.1 = f32[] parameter(0)
%scalar_rhs.1 = f32[] parameter(1)
ROOT %add.5 = f32[] add(f32[] %scalar_lhs.1, f32[] %scalar_rhs.1)
}

%fused_computation.4d (param_0: f16[2,14,14,672], param_1: f32[672], param_2: f32[]) -> (f32[672], f32[672]) {
%param_0 = f16[2,14,14,672] parameter(0)
%convert = f32[2,14,14,672] convert(%param_0)
%bitcast.1 = f32[392,672] bitcast(%convert)
%constant_0 = f32[] constant(0)
%reduce.1 = f32[672]{0} reduce(%bitcast.1, %constant_0), dimensions={0}, to_apply=%scalar_add_computation
%param_1 = f32[672] parameter(1)
%broadcast = f32[2,14,14,672] broadcast(%param_1), dimensions={3}
%multiply.1 = f32[2,14,14,672] multiply(%convert, %broadcast)
%param_2 = f32[] parameter(2)
%broadcast.1 = f32[2,14,14,672] broadcast(%param_2), dimensions={}
%multiply.2 = f32[2,14,14,672] multiply(%broadcast.1, %multiply.1)
%bitcast.2 = f32[392,672] bitcast(%multiply.2)
%reduce.2 = f32[672]{0} reduce(%bitcast.2, %constant_0), dimensions={0}, to_apply=%scalar_add_computation
ROOT %tuple = (f32[672]{0}, f32[672]{0}) tuple(%reduce.1, %reduce.2)
}

ENTRY %main {
%param_0 = f16[2,14,14,672] parameter(0)
%param_1 = f32[672] parameter(1)
%param_2 = f32[] parameter(2)
ROOT %fusion.4d = (f32[672]{0}, f32[672]{0}) fusion(%param_0, %param_1, %param_2), kind=kInput, calls=%fused_computation.4d
}
)";

  auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
  // The pass must report that it changed the module.
  EXPECT_TRUE(FusionBitcastLift().Run(module.get()).ValueOrDie());
  // Remove the old fusion, which is not used anymore.
  EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());

  // Numerical check on the original HLO text.
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));

  StatusOr<bool> filecheck_result = RunFileCheck(module->ToString(),
                                                 R"(
; CHECK-LABEL: %fused_computation
; CHECK-NOT: bitcast
; CHECK: f32[2,14,14,672]{3,2,1,0} broadcast(
; CHECK-NEXT: f32[392,672]{1,0} bitcast(
; CHECK-NOT: bitcast
; CHECK: f32[2,14,14,672]{3,2,1,0} broadcast(
; CHECK-NEXT: f32[392,672]{1,0} bitcast(
; CHECK-NOT: bitcast(
; CHECK-LABEL: ENTRY %main
; CHECK-NOT: bitcast(
; CHECK: bitcast(f16[2,14,14,672]{3,2,1,0} %param_0
; CHECK-NOT: bitcast(
)");
  // ASSERT (not EXPECT) so that ValueOrDie() below is never reached on a
  // failed status, which would abort the whole test binary.
  ASSERT_TRUE(filecheck_result.status().ok())
      << filecheck_result.status().ToString();
  EXPECT_TRUE(filecheck_result.ValueOrDie());
}
282
283 // To trigger the bitcast same pattern check.
TEST_F(FusionBitcastLiftTest,StrangeBitcastBroadcastTest)284 TEST_F(FusionBitcastLiftTest, StrangeBitcastBroadcastTest) {
285 const char* hlo_text = R"(
286 HloModule mod
287
288 %scalar_add_computation (scalar_lhs.1: f32[], scalar_rhs.1: f32[]) -> f32[] {
289 %scalar_lhs.1 = f32[] parameter(0)
290 %scalar_rhs.1 = f32[] parameter(1)
291 ROOT %add.5 = f32[] add(f32[] %scalar_lhs.1, f32[] %scalar_rhs.1)
292 }
293
294 %fused_computation.4d (param_0: f16[2,14,14,672], param_1: f32[672], param_2: f32[672]) -> (f32[672], f32[672]) {
295 %param_0 = f16[2,14,14,672] parameter(0)
296 %convert = f32[2,14,14,672] convert(%param_0)
297 %bitcast.1 = f32[392,672] bitcast(%convert)
298 %constant_0 = f32[] constant(0)
299 %reduce.1 = f32[672]{0} reduce(%bitcast.1, %constant_0), dimensions={0}, to_apply=%scalar_add_computation
300 %param_1 = f32[672] parameter(1)
301 %broadcast = f32[2,14,14,672] broadcast(%param_1), dimensions={3}
302 %multiply.1 = f32[2,14,14,672] multiply(%convert, %broadcast)
303 %param_2 = f32[672] parameter(2)
304 %broadcast.1 = f32[28,14,672] broadcast(%param_2), dimensions={2}
305 %bitcast.4 = f32[28,14,672] bitcast(%multiply.1)
306 %multiply.2 = f32[28,14,672] multiply(%broadcast.1, %bitcast.4)
307 %bitcast.2 = f32[392,672] bitcast(%multiply.2)
308 %reduce.2 = f32[672]{0} reduce(%bitcast.2, %constant_0), dimensions={0}, to_apply=%scalar_add_computation
309 ROOT %tuple = (f32[672]{0}, f32[672]{0}) tuple(%reduce.1, %reduce.2)
310 }
311
312 ENTRY %main {
313 %param_0 = f16[2,14,14,672] parameter(0)
314 %param_1 = f32[672] parameter(1)
315 %param_2 = f32[672] parameter(2)
316 ROOT %fusion.4d = (f32[672]{0}, f32[672]{0}) fusion(%param_0, %param_1, %param_2), kind=kInput, calls=%fused_computation.4d
317 }
318 )";
319
320 auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
321 EXPECT_FALSE(FusionBitcastLift().Run(module.get()).ValueOrDie());
322 }
323
// Tests a fusion whose only bitcast (f32[1] -> f32[]) is applied to the sum of
// a constant and a parameter and is used as the init value of both reduces.
//
// The FileCheck patterns expect exactly one bitcast inside the fusion after
// the pass, and one bitcast in the entry computation between parameter(1) and
// the fusion call.
TEST_F(FusionBitcastLiftTest, ConstantBitcastTest) {
  const char* hlo_text = R"(
HloModule mod

%scalar_add_computation (scalar_lhs.1: f32[], scalar_rhs.1: f32[]) -> f32[] {
%scalar_lhs.1 = f32[] parameter(0)
%scalar_rhs.1 = f32[] parameter(1)
ROOT %add.5 = f32[] add(f32[] %scalar_lhs.1, f32[] %scalar_rhs.1)
}

%fused_computation (param_0: f16[392,672], param_1: f32[1]) -> (f32[672], f32[672]) {
%param_0 = f16[392,672] parameter(0)
%convert = f32[392,672] convert(%param_0)

%param_1 = f32[1] parameter(1)
%constant_0 = f32[1] constant({1.2})
%add = f32[1] add(%constant_0, %param_1)
%bitcast.2 = f32[] bitcast(%add)

%reduce.1 = f32[672]{0} reduce(%convert, %bitcast.2), dimensions={0}, to_apply=%scalar_add_computation
%multiply = f32[392,672] multiply(%convert, %convert)
%reduce.2 = f32[672]{0} reduce(%multiply, %bitcast.2), dimensions={0}, to_apply=%scalar_add_computation
ROOT %tuple = (f32[672]{0}, f32[672]{0}) tuple(%reduce.1, %reduce.2)
}

ENTRY %main {
%param_0 = f16[392,672] parameter(0)
%param_1 = f32[1] parameter(1)
ROOT %fusion = (f32[672]{0}, f32[672]{0}) fusion(%param_0, %param_1), kind=kInput, calls=%fused_computation
}
)";
  auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
  // The pass must report that it changed the module.
  EXPECT_TRUE(FusionBitcastLift().Run(module.get()).ValueOrDie());
  // Remove the old fusion, which is not used anymore.
  EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());

  // Numerical check on the original HLO text.
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));

  StatusOr<bool> filecheck_result = RunFileCheck(module->ToString(),
                                                 R"(
; CHECK-LABEL: %fused_computation
; CHECK: f16[392,672]{1,0} parameter(0)
; CHECK-COUNT-1: bitcast(
; CHECK-NOT: bitcast(
; CHECK-LABEL: ENTRY %main
; CHECK-NEXT: f16[392,672]{1,0} parameter(0)
; CHECK-NEXT: f32[1]{0} parameter(1)
; CHECK-NEXT: bitcast(
; CHECK-NEXT: fusion(
)");
  // ASSERT (not EXPECT) so that ValueOrDie() below is never reached on a
  // failed status, which would abort the whole test binary.
  ASSERT_TRUE(filecheck_result.status().ok())
      << filecheck_result.status().ToString();
  EXPECT_TRUE(filecheck_result.ValueOrDie());
}
377
// Tests the pass on a larger multi-output fusion with seven parameters (three
// 4d f16 inputs and four f32[672] vectors), several row and scalar
// broadcasts, and two reduces fed by bitcasts to f32[392,672].
//
// The FileCheck patterns expect exactly six bitcasts inside the fusion and
// exactly three bitcasts in the entry computation before the fusion call.
TEST_F(FusionBitcastLiftTest, Swish1Test) {
  const char* hlo_text = R"(
HloModule mod

%scalar_add_computation (scalar_lhs.1: f32[], scalar_rhs.1: f32[]) -> f32[] {
%scalar_lhs.1 = f32[] parameter(0)
%scalar_rhs.1 = f32[] parameter(1)
ROOT %add.5 = f32[] add(f32[] %scalar_lhs.1, f32[] %scalar_rhs.1)
}

%fused_computation (param_0.90: f32[672], param_1.127: f16[2,14,14,672], param_2.77: f16[2,14,14,672], param_3.57: f16[2,14,14,672], param_4.57: f32[672], param_5.63: f32[672], param_6.44: f32[672]) -> (f32[672], f32[672]) {
%param_2.77 = f16[2,14,14,672]{3,2,1,0} parameter(2)
%param_3.57 = f16[2,14,14,672]{3,2,1,0} parameter(3)
%constant_153 = f16[] constant(1)
%broadcast.174 = f16[2,14,14,672]{3,2,1,0} broadcast(f16[] %constant_153), dimensions={}
%param_1.127 = f16[2,14,14,672]{3,2,1,0} parameter(1)
%convert.46 = f32[2,14,14,672]{3,2,1,0} convert(f16[2,14,14,672]{3,2,1,0} %param_1.127)
%param_0.90 = f32[672]{0} parameter(0)
%constant_77_clone_1 = f32[] constant(9.96492327e-06)
%broadcast.173 = f32[672]{0} broadcast(f32[] %constant_77_clone_1), dimensions={}
%multiply.155 = f32[672]{0} multiply(f32[672]{0} %param_0.90, f32[672]{0} %broadcast.173)
%broadcast.172 = f32[2,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %multiply.155), dimensions={3}
%subtract.55 = f32[2,14,14,672]{3,2,1,0} subtract(f32[2,14,14,672]{3,2,1,0} %convert.46, f32[2,14,14,672]{3,2,1,0} %broadcast.172)
%param_6.44 = f32[672]{0} parameter(6)
%multiply.154 = f32[672]{0} multiply(f32[672]{0} %param_6.44, f32[672]{0} %broadcast.173)
%multiply.153 = f32[672]{0} multiply(f32[672]{0} %multiply.155, f32[672]{0} %multiply.155)
%subtract.54 = f32[672]{0} subtract(f32[672]{0} %multiply.154, f32[672]{0} %multiply.153)
%constant_151 = f32[] constant(0.001)
%broadcast.171 = f32[672]{0} broadcast(f32[] %constant_151), dimensions={}
%add.50 = f32[672]{0} add(f32[672]{0} %subtract.54, f32[672]{0} %broadcast.171)
%rsqrt.23 = f32[672]{0} rsqrt(f32[672]{0} %add.50)
%broadcast.170 = f32[2,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %rsqrt.23), dimensions={3}
%multiply.152 = f32[2,14,14,672]{3,2,1,0} multiply(f32[2,14,14,672]{3,2,1,0} %subtract.55, f32[2,14,14,672]{3,2,1,0} %broadcast.170)
%param_5.63 = f32[672]{0} parameter(5)
%broadcast.169 = f32[2,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %param_5.63), dimensions={3}
%multiply.151 = f32[2,14,14,672]{3,2,1,0} multiply(f32[2,14,14,672]{3,2,1,0} %multiply.152, f32[2,14,14,672]{3,2,1,0} %broadcast.169)
%param_4.57 = f32[672]{0} parameter(4)
%broadcast.168 = f32[2,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %param_4.57), dimensions={3}
%add.48 = f32[2,14,14,672]{3,2,1,0} add(f32[2,14,14,672]{3,2,1,0} %multiply.151, f32[2,14,14,672]{3,2,1,0} %broadcast.168)
%convert.45 = f16[2,14,14,672]{3,2,1,0} convert(f32[2,14,14,672]{3,2,1,0} %add.48)
%subtract.53 = f16[2,14,14,672]{3,2,1,0} subtract(f16[2,14,14,672]{3,2,1,0} %broadcast.174, f16[2,14,14,672]{3,2,1,0} %param_3.57)
%multiply.150 = f16[2,14,14,672]{3,2,1,0} multiply(f16[2,14,14,672]{3,2,1,0} %convert.45, f16[2,14,14,672]{3,2,1,0} %subtract.53)
%add.47 = f16[2,14,14,672]{3,2,1,0} add(f16[2,14,14,672]{3,2,1,0} %broadcast.174, f16[2,14,14,672]{3,2,1,0} %multiply.150)
%multiply.149 = f16[2,14,14,672]{3,2,1,0} multiply(f16[2,14,14,672]{3,2,1,0} %param_3.57, f16[2,14,14,672]{3,2,1,0} %add.47)
%multiply.148 = f16[2,14,14,672]{3,2,1,0} multiply(f16[2,14,14,672]{3,2,1,0} %param_2.77, f16[2,14,14,672]{3,2,1,0} %multiply.149)
%convert.10 = f32[2,14,14,672]{3,2,1,0} convert(f16[2,14,14,672]{3,2,1,0} %multiply.148)
%bitcast.21 = f32[392,672]{1,0} bitcast(f32[2,14,14,672]{3,2,1,0} %convert.10)
%constant_57 = f32[] constant(0)
%reduce.9 = f32[672]{0} reduce(f32[392,672]{1,0} %bitcast.21, f32[] %constant_57), dimensions={0}, to_apply=%scalar_add_computation
%multiply.30.clone.1 = f32[2,14,14,672]{3,2,1,0} multiply(f32[2,14,14,672]{3,2,1,0} %convert.10, f32[2,14,14,672]{3,2,1,0} %subtract.55)
%bitcast.20.clone.1 = f32[392,672]{1,0} bitcast(f32[2,14,14,672]{3,2,1,0} %multiply.30.clone.1)
%reduce.8.clone.1 = f32[672]{0} reduce(f32[392,672]{1,0} %bitcast.20.clone.1, f32[] %constant_57), dimensions={0}, to_apply=%scalar_add_computation
ROOT %tuple.9 = (f32[672]{0}, f32[672]{0}) tuple(f32[672]{0} %reduce.9, f32[672]{0} %reduce.8.clone.1)
}

ENTRY %main {
%param_0 = f32[672]{0} parameter(0)
%param_1 = f16[2,14,14,672]{3,2,1,0} parameter(1)
%param_2 = f16[2,14,14,672]{3,2,1,0} parameter(2)
%param_3 = f16[2,14,14,672]{3,2,1,0} parameter(3)
%param_4 = f32[672]{0} parameter(4)
%param_5 = f32[672]{0} parameter(5)
%param_6 = f32[672]{0} parameter(6)

ROOT %fusion = (f32[672]{0}, f32[672]{0}) fusion(%param_0, %param_1, %param_2, %param_3, %param_4, %param_5, %param_6), kind=kInput, calls=%fused_computation
}
)";
  auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
  // The pass must report that it changed the module.
  EXPECT_TRUE(FusionBitcastLift().Run(module.get()).ValueOrDie());
  // Remove the old fusion, which is not used anymore.
  EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());

  // Numerical check on the original HLO text.
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));

  StatusOr<bool> filecheck_result = RunFileCheck(module->ToString(),
                                                 R"(
; CHECK-LABEL: %fused_computation
; CHECK-COUNT-6: bitcast(
; CHECK-NOT: bitcast(
; CHECK-LABEL: ENTRY %main
; CHECK-COUNT-3: bitcast(
; CHECK-NOT: bitcast(
; CHECK: fusion(
)");
  // ASSERT (not EXPECT) so that ValueOrDie() below is never reached on a
  // failed status, which would abort the whole test binary.
  ASSERT_TRUE(filecheck_result.status().ok())
      << filecheck_result.status().ToString();
  EXPECT_TRUE(filecheck_result.ValueOrDie());
}
465
// Tests the pass on a larger multi-output fusion with six parameters (two 4d
// f16 inputs and four f32[672] vectors) that also contains negate,
// exponential and divide ops, with two reduces fed by bitcasts to
// f32[392,672].
//
// The FileCheck patterns expect exactly eight bitcasts inside the fusion and
// exactly two bitcasts in the entry computation before the fusion call.
TEST_F(FusionBitcastLiftTest, Swish2Test) {
  const char* hlo_text = R"(
HloModule mod

%scalar_add_computation (scalar_lhs.1: f32[], scalar_rhs.1: f32[]) -> f32[] {
%scalar_lhs.1 = f32[] parameter(0)
%scalar_rhs.1 = f32[] parameter(1)
ROOT %add.5 = f32[] add(f32[] %scalar_lhs.1, f32[] %scalar_rhs.1)
}


%fused_computation (param_0.95: f32[672], param_1.128: f16[2,14,14,672], param_2.81: f16[2,14,14,672], param_3.66: f32[672], param_4.61: f32[672], param_5.62: f32[672]) -> (f32[672], f32[672]) {
%param_2.81 = f16[2,14,14,672]{3,2,1,0} parameter(2)
%constant_211 = f16[] constant(1)
%broadcast.288 = f16[2,14,14,672]{3,2,1,0} broadcast(f16[] %constant_211), dimensions={}
%param_1.128 = f16[2,14,14,672]{3,2,1,0} parameter(1)
%convert.74 = f32[2,14,14,672]{3,2,1,0} convert(f16[2,14,14,672]{3,2,1,0} %param_1.128)
%param_0.95 = f32[672]{0} parameter(0)
%constant_77 = f32[] constant(9.96492327e-06)
%broadcast.287 = f32[672]{0} broadcast(f32[] %constant_77), dimensions={}
%multiply.253 = f32[672]{0} multiply(f32[672]{0} %param_0.95, f32[672]{0} %broadcast.287)
%broadcast.286 = f32[2,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %multiply.253), dimensions={3}
%subtract.92 = f32[2,14,14,672]{3,2,1,0} subtract(f32[2,14,14,672]{3,2,1,0} %convert.74, f32[2,14,14,672]{3,2,1,0} %broadcast.286)
%param_5.62 = f32[672]{0} parameter(5)
%multiply.252 = f32[672]{0} multiply(f32[672]{0} %param_5.62, f32[672]{0} %broadcast.287)
%multiply.250 = f32[672]{0} multiply(f32[672]{0} %multiply.253, f32[672]{0} %multiply.253)
%subtract.91 = f32[672]{0} subtract(f32[672]{0} %multiply.252, f32[672]{0} %multiply.250)
%constant_208 = f32[] constant(0.001)
%broadcast.284 = f32[672]{0} broadcast(f32[] %constant_208), dimensions={}
%add.93 = f32[672]{0} add(f32[672]{0} %subtract.91, f32[672]{0} %broadcast.284)
%rsqrt.37 = f32[672]{0} rsqrt(f32[672]{0} %add.93)
%broadcast.283 = f32[2,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %rsqrt.37), dimensions={3}
%multiply.249 = f32[2,14,14,672]{3,2,1,0} multiply(f32[2,14,14,672]{3,2,1,0} %subtract.92, f32[2,14,14,672]{3,2,1,0} %broadcast.283)
%param_4.61 = f32[672]{0} parameter(4)
%broadcast.282 = f32[2,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %param_4.61), dimensions={3}
%multiply.248 = f32[2,14,14,672]{3,2,1,0} multiply(f32[2,14,14,672]{3,2,1,0} %multiply.249, f32[2,14,14,672]{3,2,1,0} %broadcast.282)
%param_3.66 = f32[672]{0} parameter(3)
%broadcast.281 = f32[2,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %param_3.66), dimensions={3}
%add.92 = f32[2,14,14,672]{3,2,1,0} add(f32[2,14,14,672]{3,2,1,0} %multiply.248, f32[2,14,14,672]{3,2,1,0} %broadcast.281)
%convert.73 = f16[2,14,14,672]{3,2,1,0} convert(f32[2,14,14,672]{3,2,1,0} %add.92)
%negate.14 = f16[2,14,14,672]{3,2,1,0} negate(f16[2,14,14,672]{3,2,1,0} %convert.73)
%exponential.12 = f16[2,14,14,672]{3,2,1,0} exponential(f16[2,14,14,672]{3,2,1,0} %negate.14)
%add.91 = f16[2,14,14,672]{3,2,1,0} add(f16[2,14,14,672]{3,2,1,0} %broadcast.288, f16[2,14,14,672]{3,2,1,0} %exponential.12)
%divide.22 = f16[2,14,14,672]{3,2,1,0} divide(f16[2,14,14,672]{3,2,1,0} %broadcast.288, f16[2,14,14,672]{3,2,1,0} %add.91)
%subtract.88 = f16[2,14,14,672]{3,2,1,0} subtract(f16[2,14,14,672]{3,2,1,0} %broadcast.288, f16[2,14,14,672]{3,2,1,0} %divide.22)
%multiply.241 = f16[2,14,14,672]{3,2,1,0} multiply(f16[2,14,14,672]{3,2,1,0} %convert.73, f16[2,14,14,672]{3,2,1,0} %subtract.88)
%add.87 = f16[2,14,14,672]{3,2,1,0} add(f16[2,14,14,672]{3,2,1,0} %broadcast.288, f16[2,14,14,672]{3,2,1,0} %multiply.241)
%multiply.240 = f16[2,14,14,672]{3,2,1,0} multiply(f16[2,14,14,672]{3,2,1,0} %divide.22, f16[2,14,14,672]{3,2,1,0} %add.87)
%multiply.239 = f16[2,14,14,672]{3,2,1,0} multiply(f16[2,14,14,672]{3,2,1,0} %param_2.81, f16[2,14,14,672]{3,2,1,0} %multiply.240)
%convert.9 = f32[2,14,14,672]{3,2,1,0} convert(f16[2,14,14,672]{3,2,1,0} %multiply.239)
%multiply.30 = f32[2,14,14,672]{3,2,1,0} multiply(f32[2,14,14,672]{3,2,1,0} %convert.9, f32[2,14,14,672]{3,2,1,0} %subtract.92)
%bitcast.20 = f32[392,672]{1,0} bitcast(f32[2,14,14,672]{3,2,1,0} %multiply.30)
%constant_58 = f32[] constant(0)
%reduce.8 = f32[672]{0} reduce(f32[392,672]{1,0} %bitcast.20, f32[] %constant_58), dimensions={0}, to_apply=%scalar_add_computation
%bitcast.21.clone.1 = f32[392,672]{1,0} bitcast(f32[2,14,14,672]{3,2,1,0} %convert.9)
%reduce.9.clone.1 = f32[672]{0} reduce(f32[392,672]{1,0} %bitcast.21.clone.1, f32[] %constant_58), dimensions={0}, to_apply=%scalar_add_computation
ROOT %tuple.9 = (f32[672]{0}, f32[672]{0}) tuple(f32[672]{0} %reduce.8, f32[672]{0} %reduce.9.clone.1)
}

ENTRY %main {
%param_0 = f32[672]{0} parameter(0)
%param_1 = f16[2,14,14,672]{3,2,1,0} parameter(1)
%param_2 = f16[2,14,14,672]{3,2,1,0} parameter(2)
%param_3 = f32[672]{0} parameter(3)
%param_4 = f32[672]{0} parameter(4)
%param_5 = f32[672]{0} parameter(5)

ROOT %fusion = (f32[672]{0}, f32[672]{0}) fusion(%param_0, %param_1, %param_2, %param_3, %param_4, %param_5), kind=kInput, calls=%fused_computation
}
)";
  auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
  // The pass must report that it changed the module.
  EXPECT_TRUE(FusionBitcastLift().Run(module.get()).ValueOrDie());
  // Remove the old fusion, which is not used anymore.
  EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());

  // Numerical check on the original HLO text.
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));

  StatusOr<bool> filecheck_result = RunFileCheck(module->ToString(),
                                                 R"(
; CHECK-LABEL: %fused_computation
; CHECK-COUNT-8: bitcast(
; CHECK-NOT: bitcast(
; CHECK-LABEL: ENTRY %main
; CHECK-COUNT-2: bitcast(
; CHECK-NOT: bitcast(
; CHECK: fusion(
)");
  // ASSERT (not EXPECT) so that ValueOrDie() below is never reached on a
  // failed status, which would abort the whole test binary.
  ASSERT_TRUE(filecheck_result.status().ok())
      << filecheck_result.status().ToString();
  EXPECT_TRUE(filecheck_result.ValueOrDie());
}
556
// Tests that the pass leaves the module unchanged when the bitcast operand
// does not use the default layout: here the bitcast goes from
// f32[1,1,1536,3072]{1,0,2,3} (produced by a copy inside the fusion) to
// f32[3072,1536]{1,0}. The pass is expected to report "no change".
TEST_F(FusionBitcastLiftTest, LayoutChangeNotSupported) {
  const char* hlo_text = R"(
HloModule bla

add {
param0 = f32[] parameter(0)
param1 = f32[] parameter(1)
ROOT add = f32[] add(param0, param1)
}

fused_computation {
param_1.11485 = f32[1,1,1536,3072]{3,2,1,0} parameter(1)
copy.1383 = f32[1,1,1536,3072]{1,0,2,3} copy(param_1.11485)
param_0.7122 = f32[3072]{0} parameter(0)
constant.9031 = f32[] constant(0.000651041686)
broadcast.9040 = f32[3072]{0} broadcast(constant.9031), dimensions={}
multiply.7225 = f32[3072]{0} multiply(param_0.7122, broadcast.9040)
broadcast.9039 = f32[1,1,1536,3072]{1,0,2,3} broadcast(multiply.7225), dimensions={3}
subtract.940 = f32[1,1,1536,3072]{1,0,2,3} subtract(copy.1383, broadcast.9039)
multiply.7224 = f32[1,1,1536,3072]{1,0,2,3} multiply(subtract.940, subtract.940)
bitcast.3805 = f32[3072,1536]{1,0} bitcast(multiply.7224)
constant.25971 = f32[] constant(0)
ROOT reduce.790 = f32[3072]{0} reduce(bitcast.3805, constant.25971), dimensions={1}, to_apply=add
}

ENTRY entry {
param_0.0 = f32[3072]{0} parameter(0)
param_1.0 = f32[1,1,1536,3072]{3,2,1,0} parameter(1)
ROOT fusion = f32[3072]{0} fusion(param_0.0, param_1.0), kind=kInput, calls=fused_computation
}
)";
  auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
  // The pass must report "no change" for this module.
  EXPECT_FALSE(FusionBitcastLift().Run(module.get()).ValueOrDie());
}
591
592 } // namespace
593 } // namespace gpu
594 } // namespace xla
595