• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
17 
18 #include "absl/strings/str_cat.h"
19 #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
20 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
21 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
22 #include "tensorflow/compiler/xla/service/hlo_parser.h"
23 #include "tensorflow/compiler/xla/status_macros.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25 #include "tensorflow/compiler/xla/util.h"
26 
namespace xla {
namespace gpu {

// Shorthand for HLO opcode matchers (op::Reduce(), op::Tuple(), ...) used in
// the EXPECT_THAT assertions below.
namespace op = xla::testing::opcode_matchers;

// No extra fixture state is needed beyond what HloTestBase provides.
using MultiOutputFusionTest = HloTestBase;

// Common module header: each test appends its own computations to this prefix
// via absl::StrCat. It supplies the scalar add/multiply reducer computations
// referenced by the tests' `reduce` instructions (to_apply=...).
const char kModulePrefix[] = R"(
    HloModule test_module

    scalar_add_computation {
      scalar_lhs.0 = f32[] parameter(0)
      scalar_rhs.0 = f32[] parameter(1)
      ROOT add.0 = f32[] add(scalar_lhs.0, scalar_rhs.0)
    }
    scalar_mul_computation {
      scalar_lhs.1 = f32[] parameter(0)
      scalar_rhs.1 = f32[] parameter(1)
      ROOT mul.1 = f32[] multiply(scalar_lhs.1, scalar_rhs.1)
    })";
47 
CountMultiOutputFusions(const HloModule * module)48 static int64_t CountMultiOutputFusions(const HloModule* module) {
49   int multi_output_fusion_count = 0;
50   for (auto* computation : module->MakeNonfusionComputations()) {
51     for (auto* instr : computation->instructions()) {
52       if (instr->IsMultiOutputFusion()) {
53         multi_output_fusion_count++;
54       }
55     }
56   }
57   return multi_output_fusion_count;
58 }
59 
TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
  // Fusion with reduce instruction root and a sibling reduce instruction
  // sharing the same input param.
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
    fused_computation {
      p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
      mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
      const.1 = f32[] parameter(0)
      ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
    }

    ENTRY entry {
      p0 = f32[] parameter(0)
      p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
      const.2 = f32[] constant(1)
      fusion = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation
      reduce.2 = f32[512]{0} reduce(p1, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation
      ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion, reduce.2)
    })"))
                    .ValueOrDie();
  // The pass is expected to fire and merge the sibling reduces.
  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
  SCOPED_TRACE(module->ToString());
  // Both root operands now read through GTEs of one multi-output fusion whose
  // fused root is a tuple of the two reduces.
  const HloInstruction* fusion =
      module->entry_computation()->root_instruction()->operand(0)->operand(0);
  ASSERT_TRUE(fusion->IsMultiOutputFusion());
  EXPECT_THAT(fusion->fused_expression_root(),
              op::Tuple(op::Reduce(), op::Reduce()));
}
88 
TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceInputShapes) {
  // Sibling reduce fusions whose reduced operands have different shapes
  // (f32[6400] vs. the reshaped f32[64,100]) must not be multi-output fused.
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
    fused_computation_1 {
      p1.1 = f32[6400]{0} parameter(1)
      mul = f32[6400]{0} multiply(p1.1, p1.1)
      const.1 = f32[] parameter(0)
      ROOT reduce.1 = f32[] reduce(mul, const.1), dimensions={0}, to_apply=scalar_add_computation
    }

    fused_computation_2 {
      p1.2 = f32[6400]{0} parameter(1)
      r1 = f32[64,100]{0,1} reshape(p1.2)
      const.2 = f32[] parameter(0)
      ROOT reduce.2 = f32[] reduce(r1, const.2), dimensions={1,0}, to_apply=scalar_mul_computation
    }

    ENTRY entry {
      p0 = f32[] parameter(0)
      p1 = f32[6400]{0} parameter(1)
      fusion.1 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_1
      fusion.2 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_2
      ROOT root = (f32[], f32[]) tuple(fusion.1, fusion.2)
    })"))
                    .ValueOrDie();
  // The pass must report no change.
  ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}
115 
TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceOutputShapes) {
  // Sibling reduce fusions with different output shapes (f32[] vs. f32[10])
  // must not be multi-output fused.
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
    fused_computation_1 {
      p1.1 = f32[10,10]{1,0} parameter(1)
      mul = f32[10,10]{1,0} multiply(p1.1, p1.1)
      const.1 = f32[] parameter(0)
      ROOT reduce.1 = f32[] reduce(mul, const.1), dimensions={0,1}, to_apply=scalar_add_computation
    }

    fused_computation_2 {
      p1.2 = f32[10,10]{1,0} parameter(1)
      const.2 = f32[] parameter(0)
      ROOT reduce.2 = f32[10]{0} reduce(p1.2, const.2), dimensions={0}, to_apply=scalar_mul_computation
    }

    ENTRY entry {
      p0 = f32[] parameter(0)
      p1.3 = f32[10,10]{1,0} parameter(1)
      fusion.1 = f32[] fusion(p0, p1.3), kind=kInput, calls=fused_computation_1
      p2 = f32[] parameter(2)
      fusion.2 = f32[10]{0} fusion(p2, p1.3), kind=kInput, calls=fused_computation_2
      ROOT root = (f32[], f32[10]{0}) tuple(fusion.1, fusion.2)
    })"))
                    .ValueOrDie();
  // The pass must report no change.
  ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}
142 
TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceFusions) {
  // Two sibling fusions with reduce instruction roots sharing the same input
  // param.
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
    fused_computation_1 {
      p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
      mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
      const.1 = f32[] parameter(0)
      ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
    }

    fused_computation_2 {
      p1.2 = f32[128,512,28,28]{3,2,1,0} parameter(1)
      const.2 = f32[] parameter(0)
      ROOT reduce.2 = f32[512]{0} reduce(p1.2, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation
    }

    ENTRY entry {
      p0 = f32[] parameter(0)
      p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
      fusion.1 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_1
      fusion.2 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_2
      ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion.1, fusion.2)
    })"))
                    .ValueOrDie();
  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
  SCOPED_TRACE(module->ToString());
  // Both fusions collapse into one multi-output fusion rooted at a tuple of
  // the two reduces.
  const HloInstruction* fusion =
      module->entry_computation()->root_instruction()->operand(0)->operand(0);
  ASSERT_TRUE(fusion->IsMultiOutputFusion());
  EXPECT_THAT(fusion->fused_expression_root(),
              op::Tuple(op::Reduce(), op::Reduce()));
}
176 
TEST_F(MultiOutputFusionTest,
       MultiOutputFusionSiblingReduceAndReduceMultiOutputFusion) {
  // Multi-output fusion with two reduce instructions root and a sibling reduce
  // instruction sharing the same input param.
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
    fused_computation (p0: f32[128,512,28,28]) -> (f32[512], f32[512]) {
      const.1 = f32[] constant(1)
      p0.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
      mul = f32[128,512,28,28]{3,2,1,0} multiply(f32[128,512,28,28]{3,2,1,0} p0.1, f32[128,512,28,28]{3,2,1,0} p0.1)
      reduce.1 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} mul, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
      reduce.2 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} p0.1, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
      ROOT tuple = (f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} reduce.1, f32[512]{0} reduce.2)
    }

    ENTRY entry (p0: f32[128,512,28,28]) -> (f32[512], f32[512], f32[512]) {
      p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
      const = f32[] constant(1)
      fusion = (f32[512]{0}, f32[512]{0}) fusion(f32[128,512,28,28]{3,2,1,0} p0), kind=kInput, calls=fused_computation
      get-tuple-element = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=0
      get-tuple-element.1 = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=1
      reduce.3 = f32[512]{0} reduce(p0, const), dimensions={0,2,3}, to_apply=scalar_add_computation
      ROOT root = (f32[512]{0}, f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} get-tuple-element, f32[512]{0} get-tuple-element.1, f32[512]{0} reduce.3)
    })"))
                    .ValueOrDie();
  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
  SCOPED_TRACE(module->ToString());
  // The standalone reduce.3 is absorbed: the fused root becomes a tuple of
  // three reduces.
  const HloInstruction* fusion =
      module->entry_computation()->root_instruction()->operand(0)->operand(0);
  ASSERT_TRUE(fusion->IsMultiOutputFusion());
  EXPECT_THAT(fusion->fused_expression_root(),
              op::Tuple(op::Reduce(), op::Reduce(), op::Reduce()));
}
209 
TEST_F(MultiOutputFusionTest,
       MultiOutputFusionSiblingFusionCheckAgainstReduceOperand) {
  // Verify that if we already have a multi-output fusion that we prefer to pick
  // a reduce op from its operands for checking shape compatibility.
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
    fused_computation_1 {
      p1.1 = f32[10,10]{1,0} parameter(1)
      mul = f32[10,10]{1,0} multiply(p1.1, p1.1)
      const.1 = f32[] parameter(0)
      reduce.1 = f32[] reduce(p1.1, const.1), dimensions={0,1}, to_apply=scalar_add_computation
      ROOT tuple = (f32[10,10], f32[]) tuple(mul, reduce.1)
    }

    fused_computation_2 {
      p1.2 = f32[10,10]{1,0} parameter(1)
      const.2 = f32[] parameter(0)
      ROOT reduce.2 = f32[10] reduce(p1.2, const.2), dimensions={0}, to_apply=scalar_mul_computation
    }

    ENTRY entry {
      p0 = f32[] parameter(0)
      p1 = f32[10,10]{1,0} parameter(1)
      p2 = f32[] parameter(2)
      fusion.1 = (f32[10,10], f32[]) fusion(p0, p1), kind=kInput, calls=fused_computation_1
      get-tuple-element.1 = f32[10,10] get-tuple-element((f32[10,10], f32[]) fusion.1), index=0
      get-tuple-element.2 = f32[] get-tuple-element((f32[10,10], f32[]) fusion.1), index=1
      fusion.2 = f32[10] fusion(p2, p1), kind=kInput, calls=fused_computation_2
      ROOT root = (f32[10,10], f32[], f32[10]) tuple(get-tuple-element.1, get-tuple-element.2, fusion.2)
    })"))
                    .ValueOrDie();
  // reduce.1 (f32[]) vs. reduce.2 (f32[10]) are shape-incompatible, so the
  // pass must report no change.
  ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}
242 
TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) {
  // Two sibling loop fusions over the same operand and with the same output
  // shape are merged into one multi-output loop fusion.
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
    fused_computation_1 {
      p0.1 = f32[6400]{0} parameter(0)
      ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
    }

    fused_computation_2 {
      p0.2 = f32[6400]{0} parameter(0)
      const.2 = f32[] constant(1)
      broadcast = f32[6400]{0} broadcast(const.2), dimensions={}
      ROOT div = f32[6400]{0} divide(p0.2, broadcast)
    }

    ENTRY entry {
      p0 = f32[6400]{0} parameter(0)
      fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1
      fusion.2 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_2
      ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, fusion.2)
    })"))
                    .ValueOrDie();
  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
  SCOPED_TRACE(module->ToString());
  // The merged fusion's root is a tuple of the two former fusion roots.
  const HloInstruction* fusion =
      module->entry_computation()->root_instruction()->operand(0)->operand(0);
  ASSERT_TRUE(fusion->IsMultiOutputFusion());
  EXPECT_THAT(fusion->fused_expression_root(),
              op::Tuple(op::Multiply(), op::Divide()));
}
272 
TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopReduceToInputFusion) {
  // Fusing a reduce into a loop fusion would require changing the fusion kind.
  // That's not supported yet.
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
    fused_computation_1 {
      p0.1 = f32[6400]{0} parameter(0)
      ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
    }

    ENTRY entry {
      p0 = f32[6400]{0} parameter(0)
      fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1
      const.2 = f32[] constant(0)
      reduce = f32[] reduce(p0, const.2), dimensions={0}, to_apply=scalar_add_computation
      ROOT root = (f32[6400]{0}, f32[]) tuple(fusion.1, reduce)
    })"))
                    .ValueOrDie();
  // The pass must report no change.
  ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}
292 
TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopElementwise) {
  // A loop fusion and an unfused elementwise sibling (same operand, same
  // shape) are merged into one multi-output loop fusion.
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
    fused_computation_1 {
      p0.1 = f32[6400]{0} parameter(0)
      ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
    }

    ENTRY entry {
      p0 = f32[6400]{0} parameter(0)
      fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1
      const.2 = f32[] constant(1)
      broadcast = f32[6400]{0} broadcast(const.2), dimensions={}
      div = f32[6400]{0} divide(p0, broadcast)
      ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, div)
    })"))
                    .ValueOrDie();
  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
  SCOPED_TRACE(module->ToString());
  // The `div` is pulled into the fusion alongside the multiply.
  const HloInstruction* fusion =
      module->entry_computation()->root_instruction()->operand(0)->operand(0);
  ASSERT_TRUE(fusion->IsMultiOutputFusion());
  EXPECT_THAT(fusion->fused_expression_root(),
              op::Tuple(op::Multiply(), op::Divide()));
}
317 
TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopsDifferentShapes) {
  // Sibling loop fusions with different output shapes (elementwise result vs.
  // a reduced shape) must not be multi-output fused.
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
    fused_computation_1 {
      p0.1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
      ROOT mul = f32[8,1,5,16,1,2]{5,4,3,2,1,0} multiply(p0.1, p0.1)
    }

    fused_computation_2 {
      p0.2 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
      const.2 = f32[] constant(0)
      ROOT reduce = f32[1,5,1,2]{3,2,1,0} reduce(p0.2, const.2), dimensions={0,3}, to_apply=scalar_add_computation
    }

    ENTRY entry {
      p0 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
      fusion.1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_1
      fusion.2 = f32[1,5,1,2]{3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2
      ROOT root = (f32[8,1,5,16,1,2]{5,4,3,2,1,0}, f32[1,5,1,2]{3,2,1,0}) tuple(fusion.1, fusion.2)
    })"))
                    .ValueOrDie();
  // The pass must report no change.
  ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}
340 
TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoop) {
  // A single-output loop fusion can be merged into an existing multi-output
  // loop fusion over the same operand with matching shapes.
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
    fused_computation_1 {
      p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
      mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
      exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1)
      ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0},
        f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp)
    }

    fused_computation_2 {
      p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
      const.2 = f32[] constant(0)
      broadcast = f32[8,1,5,16,1,1]{5,4,3,2,1,0} broadcast(const.2),
        dimensions={}
      ROOT add = f32[8,1,5,16,1,1]{5,4,3,2,1,0} add(p0.2, broadcast)
    }

    ENTRY entry {
      p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
      fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0},
        f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop,
        calls=fused_computation_1
      fusion.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop,
        calls=fused_computation_2
      gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0
      gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1
      ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0},
        f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0})
        tuple(gte0, gte1, fusion.2)
    })"))
                    .ValueOrDie();
  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
  SCOPED_TRACE(module->ToString());
  // The merged fusion now produces all three outputs: mul, exp, and add.
  const HloInstruction* fusion =
      module->entry_computation()->root_instruction()->operand(0)->operand(0);
  ASSERT_TRUE(fusion->IsMultiOutputFusion());
  EXPECT_THAT(fusion->fused_expression_root(),
              op::Tuple(op::Multiply(), op::Exp(), op::Add()));
}
381 
TEST_F(MultiOutputFusionTest,
       MultiOutputFusionSiblingLoopAndMultiOutputLoopDifferentShapes) {
  // A multi-output loop fusion and a sibling loop fusion with a different
  // (reduced) output shape must not be merged.
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
    fused_computation_1 {
      p0.1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
      mul = f32[8,1,5,16,1,2]{5,4,3,2,1,0} multiply(p0.1, p0.1)
      exp = f32[8,1,5,16,1,2]{5,4,3,2,1,0} exponential(p0.1)
      ROOT tuple = (f32[8,1,5,16,1,2]{5,4,3,2,1,0},
        f32[8,1,5,16,1,2]{5,4,3,2,1,0}) tuple(mul, exp)
    }

    fused_computation_2 {
      p0.2 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
      const.2 = f32[] constant(0)
      ROOT reduce = f32[1,5,1,2]{3,2,1,0} reduce(p0.2, const.2),
        dimensions={0,3}, to_apply=scalar_add_computation
    }

    ENTRY entry {
      p0 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
      fusion.1 = (f32[8,1,5,16,1,2]{5,4,3,2,1,0},
        f32[8,1,5,16,1,2]{5,4,3,2,1,0}) fusion(p0), kind=kLoop,
        calls=fused_computation_1
      fusion.2 = f32[1,5,1,2]{3,2,1,0} fusion(p0), kind=kLoop,
        calls=fused_computation_2
      gte0 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0
      gte1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1
      ROOT root = (f32[8,1,5,16,1,2]{5,4,3,2,1,0},
        f32[8,1,5,16,1,2]{5,4,3,2,1,0}, f32[1,5,1,2]{3,2,1,0})
        tuple(gte0, gte1, fusion.2)
    })"))
                    .ValueOrDie();
  // The pass must report no change.
  ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}
416 
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) {
  // `exp` feeds the reduce AND is itself a root output, so the producer is
  // fused into the consumer as a multi-output fusion emitting both values.
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
    ENTRY reduce {
      p0 = f32[32,32,32]{2,1,0} parameter(0)
      c0 = f32[] constant(0)
      exp = f32[32,32,32]{2,1,0} exponential(p0)
      reduce = f32[32,32]{1,0} reduce(exp, c0), dimensions={2},
        to_apply=scalar_add_computation
      ROOT root = (f32[32,32]{1,0}, f32[32,32,32]{2,1,0}) tuple(reduce, exp)
    })"))
                    .ValueOrDie();
  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
  SCOPED_TRACE(module->ToString());
  const HloInstruction* root = module->entry_computation()->root_instruction();
  // Both root operands are now GTEs of the same multi-output fusion.
  EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement()));
  const HloInstruction* fusion = root->operand(0)->operand(0);
  ASSERT_TRUE(fusion->IsMultiOutputFusion());
  EXPECT_THAT(fusion->fused_expression_root(),
              op::Tuple(op::Reduce(), op::Exp()));
}
437 
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
  // A producer loop fusion (`add`) consumed by a reduce and also used as a
  // root output is multi-output fused with the reduce.
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
    fused_add {
      p0.1 = f32[32,32,32]{2,1,0} parameter(0)
      p1.1 = f32[32,32,32]{2,1,0} parameter(1)
      ROOT add = f32[32,32,32]{2,1,0} add(p0.1, p1.1)
    }

    ENTRY reduce {
      p0 = f32[32,32,32]{2,1,0} parameter(0)
      p1 = f32[32,32,32]{2,1,0} parameter(1)
      c0 = f32[] constant(0)
      add = f32[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_add
      reduce = f32[32,32]{1,0} reduce(add, c0), dimensions={2},
        to_apply=scalar_add_computation
      ROOT root = (f32[32,32]{1,0}, f32[32,32,32]{2,1,0}) tuple(reduce, add)
    })"))
                    .ValueOrDie();
  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
  SCOPED_TRACE(module->ToString());
  const HloInstruction* root = module->entry_computation()->root_instruction();
  // Both root operands come from GTEs of the single resulting fusion.
  EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement()));
  const HloInstruction* fusion = root->operand(0)->operand(0);
  ASSERT_TRUE(fusion->IsMultiOutputFusion());
  EXPECT_THAT(fusion->fused_expression_root(),
              op::Tuple(op::Reduce(), op::Add()));
}
465 
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) {
  // A producer loop fusion (`select`) feeding a multi-output reduce fusion,
  // while also a root output, is folded into the reduce fusion.
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
    fused_select {
      p1.1 = f32[32,32,32]{2,1,0} parameter(1)
      c0 = f32[] constant(0)
      broadcast = f32[32,32,32]{2,1,0} broadcast(f32[] c0), dimensions={}
      greater-than = pred[32,32,32]{2,1,0} compare(f32[32,32,32]{2,1,0} p1.1,
        f32[32,32,32]{2,1,0} broadcast), direction=GT
      p0.1 = f32[32,32,32]{2,1,0} parameter(0)
      ROOT select = f32[32,32,32]{2,1,0} select(pred[32,32,32]{2,1,0}
        greater-than, f32[32,32,32]{2,1,0} p0.1, f32[32,32,32]{2,1,0} broadcast)
    }

    fused_reduce {
      p0.2 = f32[32,32,32]{2,1,0} parameter(0)
      c1 = f32[] constant(0)
      r1 = f32[32,32]{1,0} reduce(p0.2, c1), dimensions={2},
        to_apply=scalar_add_computation
      mul = f32[32,32,32]{2,1,0} multiply(p0.2, p0.2)
      r2 = f32[32,32]{1,0} reduce(mul, c1), dimensions={2},
        to_apply=scalar_add_computation
      ROOT tuple = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(r1, r2)
    }

    ENTRY reduce {
      p0 = f32[32,32,32]{2,1,0} parameter(0)
      p1 = f32[32,32,32]{2,1,0} parameter(1)
      select = f32[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select
      fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(select), kind=kInput,
        calls=fused_reduce
      gte0 = f32[32,32]{1,0} get-tuple-element(fusion), index=0
      gte1 = f32[32,32]{1,0} get-tuple-element(fusion), index=1
      ROOT root = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32,32]{2,1,0})
        tuple(gte1, gte1, select)
    })"))
                    .ValueOrDie();
  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
  SCOPED_TRACE(module->ToString());
  const HloInstruction* root = module->entry_computation()->root_instruction();
  // All three root operands now read from GTEs of one fusion whose root
  // carries both reduces plus the select.
  EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement(),
                              op::GetTupleElement()));
  const HloInstruction* fusion = root->operand(0)->operand(0);
  ASSERT_TRUE(fusion->IsMultiOutputFusion());
  EXPECT_THAT(fusion->fused_expression_root(),
              op::Tuple(op::Reduce(), op::Reduce(), op::Select()));
}
512 
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) {
  // The consumer reduce fusion here has kind=kLoop (its reduce goes through a
  // broadcast), so the producer must not be multi-output fused into it.
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
    fused_element_wise {
      p0.1 = f32[2,2,2]{2,1,0} parameter(0)
      p1.1 = f32[2,2,2]{2,1,0} parameter(1)
      ROOT root = f32[2,2,2]{2,1,0} add(p0.1, p1.1)
    }

    fused_reduce {
      p0.2 = f32[2,2,2]{2,1,0} parameter(0)
      mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2,
        f32[2,2,2]{2,1,0} p0.2)
      broadcast = f32[2,2,2,2]{3,2,1,0} broadcast(mul), dimensions={3,2,1}
      c1 = f32[] constant(0)
      ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2,2]{3,2,1,0} broadcast,
        f32[] c1), dimensions={1,3}, to_apply=scalar_add_computation
    }

    ENTRY reduce {
      p0 = f32[2,2,2]{2,1,0} parameter(0)
      p1 = f32[2,2,2]{2,1,0} parameter(1)
      element_wise = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_element_wise
      fusion = f32[2,2]{1,0} fusion(element_wise), kind=kLoop, calls=fused_reduce
      ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(fusion, element_wise)
    })"))
                    .ValueOrDie();
  // The pass must report no change.
  ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}
541 
TEST_F(MultiOutputFusionTest,
       ProducerConsumerFusionFp16LoopFusionAndReduceFusion) {
  // Same shape as the f32 LoopFusionAndReduceFusion test, but with an f16
  // producer converted to f32 inside the reduce fusion; fusion still happens.
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
    fused_select {
      p1.1 = f16[32,32,32]{2,1,0} parameter(1)
      c0 = f16[] constant(0)
      broadcast = f16[32,32,32]{2,1,0} broadcast(f16[] c0), dimensions={}
      greater-than = pred[32,32,32]{2,1,0} compare(f16[32,32,32]{2,1,0} p1.1,
        f16[32,32,32]{2,1,0} broadcast), direction=GT
      p0.1 = f16[32,32,32]{2,1,0} parameter(0)
      ROOT select = f16[32,32,32]{2,1,0} select(pred[32,32,32]{2,1,0}
        greater-than, f16[32,32,32]{2,1,0} p0.1, f16[32,32,32]{2,1,0} broadcast)
    }
    fused_reduce {
      p0.2 = f16[32,32,32]{2,1,0} parameter(0)
      convert = f32[32,32,32]{2,1,0} convert(p0.2)
      c1 = f32[] constant(0)
      r1 = f32[32,32]{1,0} reduce(convert, c1), dimensions={2},
        to_apply=scalar_add_computation
      mul = f32[32,32,32]{2,1,0} multiply(convert, convert)
      r2 = f32[32,32]{1,0} reduce(mul, c1), dimensions={2},
        to_apply=scalar_add_computation
      ROOT tuple = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(r1, r2)
    }
    ENTRY reduce {
      p0 = f16[32,32,32]{2,1,0} parameter(0)
      p1 = f16[32,32,32]{2,1,0} parameter(1)
      select = f16[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select
      fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(select), kind=kInput,
        calls=fused_reduce
      gte0 = f32[32,32]{1,0} get-tuple-element(fusion), index=0
      gte1 = f32[32,32]{1,0} get-tuple-element(fusion), index=1
      ROOT root = (f32[32,32]{1,0}, f32[32,32]{1,0}, f16[32,32,32]{2,1,0})
        tuple(gte1, gte1, select)
    })"))
                    .ValueOrDie();
  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
  SCOPED_TRACE(module->ToString());
  const HloInstruction* root = module->entry_computation()->root_instruction();
  // All root operands are GTEs of one fusion producing the reduces + select.
  EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement(),
                              op::GetTupleElement()));
  const HloInstruction* fusion = root->operand(0)->operand(0);
  ASSERT_TRUE(fusion->IsMultiOutputFusion());
  EXPECT_THAT(fusion->fused_expression_root(),
              op::Tuple(op::Reduce(), op::Reduce(), op::Select()));
}
588 
TEST_F(MultiOutputFusionTest,
       ProducerConsumerFusionReduceUnfriendlyLoopFusion) {
  // The producer loop fusion mixes operand layouts ({1,3,2,0} vs {3,2,1,0}
  // with an internal copy), which makes it unfriendly to fuse with the reduce
  // consumer — the pass must leave the module unchanged.
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
    mixed_input_layouts_computation {
      p0.1 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
      p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1)
      copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1)
      c0 = f16[] constant(0)
      broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={}
      greater-than = pred[128,1024,32,32]{1,3,2,0} compare(copy, broadcast), direction=GT
      ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast)
    }
    fused_reduce {
      p0.2 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
      convert = f32[128,1024,32,32]{1,3,2,0} convert(p0.2)
      c0.2 = f32[] constant(0)
      ROOT reduce = f32[1024]{0} reduce(convert, c0.2), dimensions={0,2,3}, to_apply=scalar_add_computation
    }
    ENTRY reduce {
      p0 = f16[128,1024,32,32]{3,2,1,0} parameter(0)
      p1 = f16[128,1024,32,32]{1,3,2,0} parameter(1)
      loop_fusion = f16[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=mixed_input_layouts_computation
      reduce_fusion = f32[1024]{0} fusion(loop_fusion), kind=kInput, calls=fused_reduce
      ROOT root = (f32[1024]{0}, f16[128,1024,32,32]{1,3,2,0}) tuple(reduce_fusion, loop_fusion)
    })"))
                    .ValueOrDie();
  // The pass must report no change.
  ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}
617 
// Multi-output fusion must never create a cycle between the resulting
// fusions. Here either producer (`add` or `mul`) could be multi-output fused
// with its reduce consumer, but doing both would make the two fusions
// mutually dependent, so exactly one fusion may be formed.
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionAvoidsCycles) {
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
    fused_add {
      p0 = f32[32,32,32]{2,1,0} parameter(0)
      p1 = f32[32,32,32]{2,1,0} parameter(1)
      ROOT add = f32[32,32,32]{2,1,0} add(p0, p1)
    }

    fused_mul {
      p2 = f32[64,64,64]{2,1,0} parameter(0)
      p3 = f32[64,64,64]{2,1,0} parameter(1)
      ROOT multiply = f32[64,64,64]{2,1,0} multiply(p2, p3)
    }

    fused_reduce_1 {
      p4 = f32[32,32,32]{2,1,0} parameter(0)
      p5 = f32[64,64,64]{2,1,0} parameter(1)
      slice = f32[32,32,32]{2,1,0} slice(p5), slice={[0:32], [0:32], [0:32]}
      add = f32[32,32,32]{2,1,0} add(p4, slice)
      c0 = f32[] constant(0)
      ROOT r1 = f32[32,32]{1,0} reduce(add, c0), dimensions={2},
        to_apply=scalar_add_computation
    }

    fused_reduce_2 {
      p6 = f32[32,32,32]{2,1,0} parameter(0)
      p7 = f32[64,64,64]{2,1,0} parameter(1)
      c0 = f32[] constant(0)
      pad = f32[64,64,64]{2,1,0} pad(p6, c0), padding=16_16x16_16x16_16
      mul = f32[64,64,64]{2,1,0} multiply(pad, p7)
      ROOT r1 = f32[64,64]{1,0} reduce(mul, c0), dimensions={2},
        to_apply=scalar_add_computation
    }

    ENTRY reduce {
      p8 = f32[32,32,32]{2,1,0} parameter(0)
      p9 = f32[64,64,64]{2,1,0} parameter(1)
      // `add` and `mul` can be multi-output fused with `reduce1` and `reduce2`,
      // respectively. However, both isn't possible, because multi-output fusion
      // will introduce an extra dependency from `add` to `mul` or vice versa.
      // Hence, the second multi-output fusion would introduce a cycle.
      add = f32[32,32,32]{2,1,0} fusion(p8, p8), kind=kLoop, calls=fused_add
      mul = f32[64,64,64]{2,1,0} fusion(p9, p9), kind=kLoop, calls=fused_mul

      reduce1 = f32[32,32]{1,0} fusion(add, mul), kind=kInput,
          calls=fused_reduce_1
      reduce2 = f32[64,64]{1,0} fusion(add, mul), kind=kInput,
          calls=fused_reduce_2
      ROOT root = (f32[32,32,32]{2,1,0}, f32[32,32]{1,0}, f32[64,64]{1,0},
                   f32[64,64,64]{2,1,0}) tuple(add, reduce1, reduce2, mul)
    })"))
                    .value();
  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).value());
  SCOPED_TRACE(module->ToString());
  // Only one of the two candidate multi-output fusions is legal.
  EXPECT_EQ(1, CountMultiOutputFusions(module.get()));
}
674 
// When the producer `add` has both a fusion consumer (`reduce`) and an
// unfused consumer (`reduce2`), fusing the producer into the fusion consumer
// is preferred; exactly one multi-output fusion should result.
TEST_F(MultiOutputFusionTest, PreferFuseProducerIntoFusionConsumer) {
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
    fused_add {
      p0 = f32[32,32,32]{2,1,0} parameter(0)
      p1 = f32[32,32,32]{2,1,0} parameter(1)
      ROOT add = f32[32,32,32]{2,1,0} add(p0, p1)
    }
    fused_reduce {
      p0 = f32[32,32,32]{2,1,0} parameter(0)
      p1 = f32[64,64,64]{2,1,0} parameter(1)
      slice = f32[32,32,32]{2,1,0} slice(p1), slice={[0:32], [0:32], [0:32]}
      add = f32[32,32,32]{2,1,0} add(p0, slice)
      c0 = f32[] constant(0)
      ROOT r1 = f32[32,32]{1,0} reduce(add, c0), dimensions={2},
        to_apply=scalar_add_computation
    }
    ENTRY reduce {
      p0 = f32[32,32,32]{2,1,0} parameter(0)
      p1 = f32[64,64,64]{2,1,0} parameter(1)
      add = f32[32,32,32]{2,1,0} fusion(p0, p0), kind=kLoop, calls=fused_add
      c0 = f32[] constant(0)
      reduce2 = f32[32,32]{1,0} reduce(add, c0), dimensions={2},
        to_apply=scalar_add_computation
      reduce = f32[32,32]{1,0} fusion(add, p1), kind=kInput, calls=fused_reduce
      ROOT root = (f32[32,32,32]{2,1,0}, f32[32,32]{1,0}, f32[32,32]{1,0})
                  tuple(add, reduce, reduce2)
    })"))
                    .value();
  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).value());
  SCOPED_TRACE(module->ToString());
  // Reuse the file-level helper instead of re-implementing the count loop.
  EXPECT_EQ(1, CountMultiOutputFusions(module.get()));
}
715 
716 // Check that we limit the number of operands to fusions we create.
TEST_F(MultiOutputFusionTest, AvoidsLargeFusion) {
  constexpr int64_t kNumParams = 200;
  // The test is only meaningful if we have more parameters than the pass's
  // operand/output budget allows in a single fusion.
  ASSERT_GT(kNumParams, MaxOperandsAndOutputsPerFusion());

  // Compute
  //   p0 * p1,
  //   p0 * p1 + p1 * p2
  //   p0 * p1 + p1 * p2 + p2 * p3
  //   ...
  // where each of the (pi * pj)'s is represented as a fusion node so that
  // multi-output fusion will pay attention to it.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {10, 100});

  std::vector<HloInstruction*> params;
  params.reserve(kNumParams);  // size known up front; avoid reallocations
  for (int64_t i = 0; i < kNumParams; ++i) {
    params.push_back(
        b.AddInstruction(HloInstruction::CreateParameter(i, shape, "p")));
  }

  // Creates a fusion node that calculates x*y.
  auto make_fusion = [&](HloInstruction* x, HloInstruction* y) {
    HloComputation::Builder sub_builder("subcomp");
    auto* p0 = sub_builder.AddInstruction(
        HloInstruction::CreateParameter(0, shape, "p"));
    auto* p1 = sub_builder.AddInstruction(
        HloInstruction::CreateParameter(1, shape, "p"));
    sub_builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1));
    HloComputation* subcomp =
        module->AddEmbeddedComputation(sub_builder.Build());
    return HloInstruction::CreateFusion(
        shape, HloInstruction::FusionKind::kLoop, {x, y}, subcomp);
  };

  // Chain of adds over the pairwise-product fusions (see diagram above).
  auto* sum = b.AddInstruction(make_fusion(params[0], params[1]));
  for (int64_t i = 2; i < kNumParams; ++i) {
    sum = b.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kAdd, sum,
        b.AddInstruction(make_fusion(params[i - 1], params[i]))));
  }
  auto computation = module->AddEntryComputation(b.Build());
  EXPECT_TRUE(GpuMultiOutputFusion().Run(module.get()).value());
  SCOPED_TRACE(module->ToString());
  // After fusion, no instruction may exceed the operand + output budget.
  for (const HloInstruction* instr : computation->instructions()) {
    EXPECT_LE(instr->operand_count() + ShapeUtil::SubshapeCount(instr->shape()),
              MaxOperandsAndOutputsPerFusion())
        << instr->ToString();
  }
}
768 
// Two loop fusions rooted at dynamic-update-slice must not be multi-output
// fused together (the pass must leave the module unchanged here).
TEST_F(MultiOutputFusionTest, MultiOutputFusionDUS) {
  auto module = ParseAndReturnVerifiedModule(R"(HloModule dus_mof
    fusion.1 {
      p.0 = f16[50,96,1024]{2,1,0} parameter(0)
      p.1 = f16[1,96,1024]{2,1,0} parameter(1)
      c.0 = s32[3]{0} constant({0, 0, 0})
      ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.1, c.0)
    }

    fusion.2 {
      p.0 = f16[50,96,1024]{2,1,0} parameter(0)
      p.1 = f16[1,96,1024]{2,1,0} parameter(1)
      c.0 = s32[3]{0} constant({0, 0, 0})
      ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.1, c.0)
    }

    ENTRY entry {
      p.00 = f16[50,96,1024]{2,1,0} parameter(0)
      p.01 = f16[50,96,1024]{2,1,0} parameter(1)
      p.1 = f16[1,96,1024]{2,1,0} parameter(2)

      f1 = f16[50,96,1024] fusion(p.00, p.1), kind=kLoop, calls=fusion.1
      f2 = f16[50,96,1024] fusion(p.01, p.1), kind=kLoop, calls=fusion.2
      ROOT tuple = (f16[50,96,1024],f16[50,96,1024]) tuple(f1, f2)
    })")
                    .value();
  ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).value());
}
797 
798 // Check that we don't fuse too many reductions together.
// Ten identical reduction fusions over overlapping inputs: the shared-memory
// budget limits how many reductions land in any single multi-output fusion.
TEST_F(MultiOutputFusionTest, SharedMemoryBudget) {
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
    fused_computation0 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
        to_apply=scalar_add_computation
    }
    fused_computation1 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
        to_apply=scalar_add_computation
    }
    fused_computation2 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
        to_apply=scalar_add_computation
    }
    fused_computation3 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
        to_apply=scalar_add_computation
    }
    fused_computation4 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
        to_apply=scalar_add_computation
    }
    fused_computation5 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
        to_apply=scalar_add_computation
    }
    fused_computation6 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
        to_apply=scalar_add_computation
    }
    fused_computation7 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
        to_apply=scalar_add_computation
    }
    fused_computation8 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
        to_apply=scalar_add_computation
    }
    fused_computation9 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
        to_apply=scalar_add_computation
    }
    ENTRY computation {
      zero = f32[] constant(0)
      param0 = f32[64,64] parameter(0)
      param1 = f32[64,64] parameter(1)
      param2 = f32[64,64] parameter(2)
      param3 = f32[64,64] parameter(3)
      param4 = f32[64,64] parameter(4)
      param5 = f32[64,64] parameter(5)
      param6 = f32[64,64] parameter(6)
      param7 = f32[64,64] parameter(7)
      param8 = f32[64,64] parameter(8)
      param9 = f32[64,64] parameter(9)
      out0 = f32[64] fusion(param0, param1, zero), kind=kInput, calls=fused_computation0
      out1 = f32[64] fusion(param1, param2, zero), kind=kInput, calls=fused_computation1
      out2 = f32[64] fusion(param2, param3, zero), kind=kInput, calls=fused_computation2
      out3 = f32[64] fusion(param3, param4, zero), kind=kInput, calls=fused_computation3
      out4 = f32[64] fusion(param4, param5, zero), kind=kInput, calls=fused_computation4
      out5 = f32[64] fusion(param5, param6, zero), kind=kInput, calls=fused_computation5
      out6 = f32[64] fusion(param6, param7, zero), kind=kInput, calls=fused_computation6
      out7 = f32[64] fusion(param7, param8, zero), kind=kInput, calls=fused_computation7
      out8 = f32[64] fusion(param8, param9, zero), kind=kInput, calls=fused_computation8
      out9 = f32[64] fusion(param9, param0, zero), kind=kInput, calls=fused_computation9
      ROOT out = (f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64]) tuple(f32[64] out0, f32[64] out1, f32[64] out2, f32[64] out3, f32[64] out4, f32[64] out5, f32[64] out6, f32[64] out7, f32[64] out8, f32[64] out9)
    }
  )"))
                    .value();
  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).value());

  // The 10 reductions are grouped into 3 multi-output fusions, not 1.
  EXPECT_EQ(3, CountMultiOutputFusions(module.get()));
}
911 
// Same layout as SharedMemoryBudget, but reducing along dimension 1 instead
// of dimension 0; grouping is limited to 2 multi-output fusions here.
TEST_F(MultiOutputFusionTest, DoNotGroupTooManyReductions) {
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
    fused_computation0 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
        to_apply=scalar_add_computation
    }
    fused_computation1 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
        to_apply=scalar_add_computation
    }
    fused_computation2 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
        to_apply=scalar_add_computation
    }
    fused_computation3 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
        to_apply=scalar_add_computation
    }
    fused_computation4 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
        to_apply=scalar_add_computation
    }
    fused_computation5 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
        to_apply=scalar_add_computation
    }
    fused_computation6 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
        to_apply=scalar_add_computation
    }
    fused_computation7 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
        to_apply=scalar_add_computation
    }
    fused_computation8 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
        to_apply=scalar_add_computation
    }
    fused_computation9 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
        to_apply=scalar_add_computation
    }
    ENTRY computation {
      zero = f32[] constant(0)
      param0 = f32[64,64] parameter(0)
      param1 = f32[64,64] parameter(1)
      param2 = f32[64,64] parameter(2)
      param3 = f32[64,64] parameter(3)
      param4 = f32[64,64] parameter(4)
      param5 = f32[64,64] parameter(5)
      param6 = f32[64,64] parameter(6)
      param7 = f32[64,64] parameter(7)
      param8 = f32[64,64] parameter(8)
      param9 = f32[64,64] parameter(9)
      out0 = f32[64] fusion(param0, param1, zero), kind=kInput, calls=fused_computation0
      out1 = f32[64] fusion(param1, param2, zero), kind=kInput, calls=fused_computation1
      out2 = f32[64] fusion(param2, param3, zero), kind=kInput, calls=fused_computation2
      out3 = f32[64] fusion(param3, param4, zero), kind=kInput, calls=fused_computation3
      out4 = f32[64] fusion(param4, param5, zero), kind=kInput, calls=fused_computation4
      out5 = f32[64] fusion(param5, param6, zero), kind=kInput, calls=fused_computation5
      out6 = f32[64] fusion(param6, param7, zero), kind=kInput, calls=fused_computation6
      out7 = f32[64] fusion(param7, param8, zero), kind=kInput, calls=fused_computation7
      out8 = f32[64] fusion(param8, param9, zero), kind=kInput, calls=fused_computation8
      out9 = f32[64] fusion(param9, param0, zero), kind=kInput, calls=fused_computation9
      ROOT out = (f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64]) tuple(f32[64] out0, f32[64] out1, f32[64] out2, f32[64] out3, f32[64] out4, f32[64] out5, f32[64] out6, f32[64] out7, f32[64] out8, f32[64] out9)
    }
  )"))
                    .value();
  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).value());

  // The 10 reductions are grouped into 2 multi-output fusions, not 1.
  EXPECT_EQ(2, CountMultiOutputFusions(module.get()));
}
1024 
// Regression test: multi-output fusing fusion.2 into the already-two-output
// fusion.1 must be rejected (per the test name, to stay within the shared
// memory budget), so the pass must report no change.
TEST_F(MultiOutputFusionTest, NoFusionToAvoidUsingTooMuchSharedMemory) {
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule xla_computation_update_step.10931

%scalar_add_computation.1 (scalar_lhs.1: f64[], scalar_rhs.1: f64[]) -> f64[] {
  %scalar_lhs.1 = f64[] parameter(0)
  %scalar_rhs.1 = f64[] parameter(1)
  ROOT %add.1257 = f64[] add(f64[] %scalar_lhs.1, f64[] %scalar_rhs.1)
}

%fused_computation.1 (param_0.8: f64[64,64], param_1.11: f64[64,64], param_2.9: f64[64,64]) -> (f64[64], f64[64]) {
  %param_0.8 = f64[64,64]{1,0} parameter(0)
  %param_1.11 = f64[64,64]{1,0} parameter(1)
  %multiply.2 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %param_0.8, f64[64,64]{1,0} %param_1.11)
  %constant_5217.3 = f64[] constant(0)
  %broadcast.1 = f64[64,64]{1,0} broadcast(f64[] %constant_5217.3), dimensions={}
  %multiply.0 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %multiply.2, f64[64,64]{1,0} %broadcast.1)
  %reduce.0 = f64[64]{0} reduce(f64[64,64]{1,0} %multiply.0, f64[] %constant_5217.3), dimensions={0}, to_apply=%scalar_add_computation.1
  %param_2.9 = f64[64,64]{1,0} parameter(2)
  %multiply.1514.clone.0.clone.1 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %param_2.9, f64[64,64]{1,0} %param_1.11)
  %constant_5217.1.clone.1 = f64[] constant(0)
  %broadcast.0.clone.1 = f64[64,64]{1,0} broadcast(f64[] %constant_5217.1.clone.1), dimensions={}
  %multiply.1341.clone.0.clone.1 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %multiply.1514.clone.0.clone.1, f64[64,64]{1,0} %broadcast.0.clone.1)
  %reduce.630.clone.0.clone.1 = f64[64]{0} reduce(f64[64,64]{1,0} %multiply.1341.clone.0.clone.1, f64[] %constant_5217.1.clone.1), dimensions={0}, to_apply=%scalar_add_computation.1
  ROOT %tuple = (f64[64]{0}, f64[64]{0}) tuple(f64[64]{0} %reduce.0, f64[64]{0} %reduce.630.clone.0.clone.1)
}

%primitive_computation_add__1.6426 (parameter.6427: f64[], parameter.6428: f64[]) -> f64[] {
  %parameter.6427 = f64[] parameter(0)
  %parameter.6428 = f64[] parameter(1)
  ROOT %add.6429 = f64[] add(f64[] %parameter.6427, f64[] %parameter.6428)
}

%fused_computation.2 (param_0.7: f64[64,64], param_1.9: f64[64,64]) -> f64[64] {
  %param_0.7 = f64[64,64]{1,0} parameter(0)
  %param_1.9 = f64[64,64]{1,0} parameter(1)
  %multiply.1 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %param_0.7, f64[64,64]{1,0} %param_1.9)
  %constant_5217.2 = f64[] constant(0)
  ROOT %reduce.740.clone.0 = f64[64]{0} reduce(f64[64,64]{1,0} %multiply.1, f64[] %constant_5217.2), dimensions={0}, to_apply=%primitive_computation_add__1.6426
}

ENTRY %reproducer (param_0.1090: f64[64,64], param_1.1377: f64[64,64], param_2.1948: f64[64,64]) -> (f64[64], f64[64], f64[64]) {
  %param_0.1090 = f64[64,64]{1,0} parameter(0)
  %param_1.1377 = f64[64,64]{1,0} parameter(1)
  %param_2.1948 = f64[64,64]{1,0} parameter(2)
  %fusion.1 = (f64[64]{0}, f64[64]{0}) fusion(f64[64,64]{1,0} %param_0.1090, f64[64,64]{1,0} %param_1.1377, f64[64,64]{1,0} %param_2.1948), kind=kInput, calls=%fused_computation.1
  %get-tuple-element = f64[64]{0} get-tuple-element((f64[64]{0}, f64[64]{0}) %fusion.1), index=0
  %fusion.2 = f64[64]{0} fusion(f64[64,64]{1,0} %param_0.1090, f64[64,64]{1,0} %param_1.1377), kind=kInput, calls=%fused_computation.2
  %get-tuple-element.1 = f64[64]{0} get-tuple-element((f64[64]{0}, f64[64]{0}) %fusion.1), index=1
  ROOT %tuple.428 = (f64[64]{0}, f64[64]{0}, f64[64]{0}) tuple(f64[64]{0} %get-tuple-element, f64[64]{0} %fusion.2, f64[64]{0} %get-tuple-element.1)
}
  )")
                    .value();
  EXPECT_FALSE(GpuMultiOutputFusion().Run(module.get()).value());
}
1080 
// fusion.2 consumes fusion.1, and fusion.1 has another user (the root tuple).
// Per the test name, multi-output fusing them would duplicate fusion.1's long
// dynamic-update-slice chain, so the pass must report no change.
TEST_F(MultiOutputFusionTest, NoFusionToAvoidCodeDuplication) {
  auto module = ParseAndReturnVerifiedModule(R"(
HloModule module

and.reduce_sub_computation {
  x = pred[] parameter(0)
  y = pred[] parameter(1)
  ROOT and = pred[] and(x, y)
}

fused_computation.1 {
  param_4.658 = f32[2,20,256]{2,0,1} parameter(4)
  slice.1385 = f32[2,1,256]{2,0,1} slice(param_4.658), slice={[0:2], [11:12], [0:256]}
  constant.6847 = s32[] constant(0)
  broadcast.4823 = s32[3]{0} broadcast(constant.6847), dimensions={}
  param_9.415 = s32[3]{0} parameter(9)
  compare.700 = pred[3]{0} compare(broadcast.4823, param_9.415), direction=LE
  constant.6846 = pred[] constant(true)
  reduce.221 = pred[] reduce(compare.700, constant.6846), dimensions={0}, to_apply=and.reduce_sub_computation
  broadcast.2933 = pred[2,1,256]{2,0,1} broadcast(reduce.221), dimensions={}
  param_5.528 = f32[2,512]{1,0} parameter(5)
  slice.1384 = f32[2,256]{1,0} slice(param_5.528), slice={[0:2], [0:256]}
  bitcast.341 = f32[2,1,256]{2,0,1} bitcast(slice.1384)
  constant.5418 = f32[] constant(0)
  broadcast.3227 = f32[2,1,256]{2,0,1} broadcast(constant.5418), dimensions={}
  select.173 = f32[2,1,256]{2,0,1} select(broadcast.2933, bitcast.341, broadcast.3227)
  add.573 = f32[2,1,256]{2,0,1} add(slice.1385, select.173)
  param_0.299 = s32[] parameter(0)
  constant.5157 = s32[] constant(11)
  dynamic-update-slice.189 = f32[2,20,256]{2,0,1} dynamic-update-slice(param_4.658, add.573, param_0.299, constant.5157, param_0.299)
  slice.1383 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.189), slice={[0:2], [10:11], [0:256]}
  constant.6800 = s32[] constant(0)
  broadcast.4803 = s32[3]{0} broadcast(constant.6800), dimensions={}
  param_8.484 = s32[3]{0} parameter(8)
  compare.681 = pred[3]{0} compare(broadcast.4803, param_8.484), direction=LE
  constant.6798 = pred[] constant(true)
  reduce.203 = pred[] reduce(compare.681, constant.6798), dimensions={0}, to_apply=and.reduce_sub_computation
  broadcast.2932 = pred[2,1,256]{2,0,1} broadcast(reduce.203), dimensions={}
  param_3.1169 = f32[2,512]{1,0} parameter(3)
  slice.1382 = f32[2,256]{1,0} slice(param_3.1169), slice={[0:2], [0:256]}
  bitcast.340 = f32[2,1,256]{2,0,1} bitcast(slice.1382)
  select.172 = f32[2,1,256]{2,0,1} select(broadcast.2932, bitcast.340, broadcast.3227)
  add.572 = f32[2,1,256]{2,0,1} add(slice.1383, select.172)
  constant.5154 = s32[] constant(10)
  dynamic-update-slice.188 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.189, add.572, param_0.299, constant.5154, param_0.299)
  slice.1381 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.188), slice={[0:2], [9:10], [0:256]}
  constant.6794 = s32[] constant(0)
  broadcast.4801 = s32[3]{0} broadcast(constant.6794), dimensions={}
  param_7.478 = s32[3]{0} parameter(7)
  compare.679 = pred[3]{0} compare(broadcast.4801, param_7.478), direction=LE
  constant.6793 = pred[] constant(true)
  reduce.201 = pred[] reduce(compare.679, constant.6793), dimensions={0}, to_apply=and.reduce_sub_computation
  broadcast.2930 = pred[2,1,256]{2,0,1} broadcast(reduce.201), dimensions={}
  param_2.1685 = f32[2,512]{1,0} parameter(2)
  slice.1380 = f32[2,256]{1,0} slice(param_2.1685), slice={[0:2], [0:256]}
  bitcast.339 = f32[2,1,256]{2,0,1} bitcast(slice.1380)
  select.171 = f32[2,1,256]{2,0,1} select(broadcast.2930, bitcast.339, broadcast.3227)
  add.571 = f32[2,1,256]{2,0,1} add(slice.1381, select.171)
  constant.5153 = s32[] constant(9)
  dynamic-update-slice.187 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.188, add.571, param_0.299, constant.5153, param_0.299)
  slice.1379 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.187), slice={[0:2], [8:9], [0:256]}
  constant.6788 = s32[] constant(0)
  broadcast.4799 = s32[3]{0} broadcast(constant.6788), dimensions={}
  param_6.495 = s32[3]{0} parameter(6)
  compare.677 = pred[3]{0} compare(broadcast.4799, param_6.495), direction=LE
  constant.6786 = pred[] constant(true)
  reduce.199 = pred[] reduce(compare.677, constant.6786), dimensions={0}, to_apply=and.reduce_sub_computation
  broadcast.2929 = pred[2,1,256]{2,0,1} broadcast(reduce.199), dimensions={}
  param_1.1408 = f32[2,512]{1,0} parameter(1)
  slice.1378 = f32[2,256]{1,0} slice(param_1.1408), slice={[0:2], [0:256]}
  bitcast.338 = f32[2,1,256]{2,0,1} bitcast(slice.1378)
  select.170 = f32[2,1,256]{2,0,1} select(broadcast.2929, bitcast.338, broadcast.3227)
  add.570 = f32[2,1,256]{2,0,1} add(slice.1379, select.170)
  constant.5152 = s32[] constant(8)
  ROOT dynamic-update-slice.186 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.187, add.570, param_0.299, constant.5152, param_0.299)
}

fused_computation.2 {
  param_4.655 = f32[2,20,256]{2,0,1} parameter(4)
  slice.1369 = f32[2,1,256]{2,0,1} slice(param_4.655), slice={[0:2], [7:8], [0:256]}
  param_6.483 = pred[] parameter(6)
  broadcast.2927 = pred[2,1,256]{2,0,1} broadcast(param_6.483), dimensions={}
  param_5.525 = f32[2,512]{1,0} parameter(5)
  slice.1368 = f32[2,256]{1,0} slice(param_5.525), slice={[0:2], [0:256]}
  bitcast.333 = f32[2,1,256]{2,0,1} bitcast(slice.1368)
  constant.5415 = f32[] constant(0)
  broadcast.3225 = f32[2,1,256]{2,0,1} broadcast(constant.5415), dimensions={}
  select.161 = f32[2,1,256]{2,0,1} select(broadcast.2927, bitcast.333, broadcast.3225)
  add.549 = f32[2,1,256]{2,0,1} add(slice.1369, select.161)
  param_0.265 = s32[] parameter(0)
  constant.5151 = s32[] constant(7)
  dynamic-update-slice.185 = f32[2,20,256]{2,0,1} dynamic-update-slice(param_4.655, add.549, param_0.265, constant.5151, param_0.265)
  slice.1367 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.185), slice={[0:2], [6:7], [0:256]}
  constant.6782 = s32[] constant(0)
  broadcast.4797 = s32[3]{0} broadcast(constant.6782), dimensions={}
  param_9.391 = s32[3]{0} parameter(9)
  compare.675 = pred[3]{0} compare(broadcast.4797, param_9.391), direction=LE
  constant.6781 = pred[] constant(true)
  reduce.197 = pred[] reduce(compare.675, constant.6781), dimensions={0}, to_apply=and.reduce_sub_computation
  broadcast.2926 = pred[2,1,256]{2,0,1} broadcast(reduce.197), dimensions={}
  param_3.1167 = f32[2,512]{1,0} parameter(3)
  slice.1366 = f32[2,256]{1,0} slice(param_3.1167), slice={[0:2], [0:256]}
  bitcast.332 = f32[2,1,256]{2,0,1} bitcast(slice.1366)
  select.160 = f32[2,1,256]{2,0,1} select(broadcast.2926, bitcast.332, broadcast.3225)
  add.548 = f32[2,1,256]{2,0,1} add(slice.1367, select.160)
  constant.5150 = s32[] constant(6)
  dynamic-update-slice.184 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.185, add.548, param_0.265, constant.5150, param_0.265)
  slice.1365 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.184), slice={[0:2], [5:6], [0:256]}
  constant.6776 = s32[] constant(0)
  broadcast.4794 = s32[3]{0} broadcast(constant.6776), dimensions={}
  param_8.464 = s32[3]{0} parameter(8)
  compare.673 = pred[3]{0} compare(broadcast.4794, param_8.464), direction=LE
  constant.6775 = pred[] constant(true)
  reduce.195 = pred[] reduce(compare.673, constant.6775), dimensions={0}, to_apply=and.reduce_sub_computation
  broadcast.2925 = pred[2,1,256]{2,0,1} broadcast(reduce.195), dimensions={}
  param_2.1684 = f32[2,512]{1,0} parameter(2)
  slice.1364 = f32[2,256]{1,0} slice(param_2.1684), slice={[0:2], [0:256]}
  bitcast.331 = f32[2,1,256]{2,0,1} bitcast(slice.1364)
  select.159 = f32[2,1,256]{2,0,1} select(broadcast.2925, bitcast.331, broadcast.3225)
  add.547 = f32[2,1,256]{2,0,1} add(slice.1365, select.159)
  constant.5149 = s32[] constant(5)
  dynamic-update-slice.183 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.184, add.547, param_0.265, constant.5149, param_0.265)
  slice.1363 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.183), slice={[0:2], [4:5], [0:256]}
  constant.6770 = s32[] constant(0)
  broadcast.4792 = s32[3]{0} broadcast(constant.6770), dimensions={}
  param_7.458 = s32[3]{0} parameter(7)
  compare.671 = pred[3]{0} compare(broadcast.4792, param_7.458), direction=LE
  constant.6769 = pred[] constant(true)
  reduce.193 = pred[] reduce(compare.671, constant.6769), dimensions={0}, to_apply=and.reduce_sub_computation
  broadcast.2924 = pred[2,1,256]{2,0,1} broadcast(reduce.193), dimensions={}
  param_1.1405 = f32[2,512]{1,0} parameter(1)
  slice.1362 = f32[2,256]{1,0} slice(param_1.1405), slice={[0:2], [0:256]}
  bitcast.330 = f32[2,1,256]{2,0,1} bitcast(slice.1362)
  select.158 = f32[2,1,256]{2,0,1} select(broadcast.2924, bitcast.330, broadcast.3225)
  add.546 = f32[2,1,256]{2,0,1} add(slice.1363, select.158)
  constant.5148 = s32[] constant(4)
  ROOT dynamic-update-slice.182 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.183, add.546, param_0.265, constant.5148, param_0.265)
}

ENTRY main {
  param_0.0 = s32[] parameter(0)
  param_1.0 = f32[2,512]{1,0} parameter(1)
  param_2.0 = f32[2,512]{1,0} parameter(2)
  param_3.0 = f32[2,512]{1,0} parameter(3)
  param_4.0 = f32[2,20,256]{2,1,0} parameter(4)
  param_5.0 = f32[2,512]{1,0} parameter(5)
  param_6.0 = s32[3]{0} parameter(6)
  param_7.0 = s32[3]{0} parameter(7)
  param_8.0 = s32[3]{0} parameter(8)
  param_9.0 = s32[3]{0} parameter(9)
  fusion.1 = f32[2,20,256]{2,0,1} fusion(param_0.0, param_1.0, param_2.0, param_3.0, param_4.0, param_5.0, param_6.0, param_7.0, param_8.0, param_9.0), kind=kLoop, calls=fused_computation.1
  param_10 = pred[] parameter(10)
  fusion.2 = f32[2,20,256]{2,0,1} fusion(param_0.0, param_1.0, param_2.0, param_3.0, fusion.1, param_5.0, param_10, param_7.0, param_8.0, param_9.0), kind=kLoop, calls=fused_computation.2
  ROOT root = (f32[2,20,256]{2,0,1}, f32[2,20,256]{2,0,1}) tuple(fusion.1, fusion.2)
}
  )")
                    .value();
  EXPECT_FALSE(GpuMultiOutputFusion().Run(module.get()).value());
}
1240 
1241 }  // namespace gpu
1242 }  // namespace xla
1243