1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
17
18 #include "absl/strings/str_cat.h"
19 #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
20 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
21 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
22 #include "tensorflow/compiler/xla/service/hlo_parser.h"
23 #include "tensorflow/compiler/xla/status_macros.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25 #include "tensorflow/compiler/xla/util.h"
26
27 namespace xla {
28 namespace gpu {
29
30 namespace op = xla::testing::opcode_matchers;
31
32 using MultiOutputFusionTest = HloTestBase;
33
// Common HLO module header shared by every test below: it declares the
// scalar add/multiply computations used as reduce combiners.  Each test
// appends its own fusions and ENTRY computation via absl::StrCat.
const char kModulePrefix[] = R"(
HloModule test_module

scalar_add_computation {
scalar_lhs.0 = f32[] parameter(0)
scalar_rhs.0 = f32[] parameter(1)
ROOT add.0 = f32[] add(scalar_lhs.0, scalar_rhs.0)
}
scalar_mul_computation {
scalar_lhs.1 = f32[] parameter(0)
scalar_rhs.1 = f32[] parameter(1)
ROOT mul.1 = f32[] multiply(scalar_lhs.1, scalar_rhs.1)
})";
47
CountMultiOutputFusions(const HloModule * module)48 static int64_t CountMultiOutputFusions(const HloModule* module) {
49 int multi_output_fusion_count = 0;
50 for (auto* computation : module->MakeNonfusionComputations()) {
51 for (auto* instr : computation->instructions()) {
52 if (instr->IsMultiOutputFusion()) {
53 multi_output_fusion_count++;
54 }
55 }
56 }
57 return multi_output_fusion_count;
58 }
59
// Fusion with reduce instruction root and a sibling reduce instruction
// sharing the same input param: the pass should merge them into a single
// multi-output fusion whose root is a tuple of both reduces.
TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
fused_computation {
p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
const.1 = f32[] parameter(0)
ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
}

ENTRY entry {
p0 = f32[] parameter(0)
p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
const.2 = f32[] constant(1)
fusion = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation
reduce.2 = f32[512]{0} reduce(p1, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation
ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion, reduce.2)
})"))
                    .ValueOrDie();
  // The pass must report that it changed the module.
  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
  SCOPED_TRACE(module->ToString());
  // After fusion, the entry root's operands are get-tuple-elements of the
  // multi-output fusion; walk through one of them to reach the fusion.
  const HloInstruction* fusion =
      module->entry_computation()->root_instruction()->operand(0)->operand(0);
  ASSERT_TRUE(fusion->IsMultiOutputFusion());
  EXPECT_THAT(fusion->fused_expression_root(),
              op::Tuple(op::Reduce(), op::Reduce()));
}
88
// Sibling reduces whose operands have different shapes (f32[6400] vs. a
// f32[64,100] reshape of it) are not compatible for multi-output fusion,
// so the pass must leave the module unchanged.
TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceInputShapes) {
  auto hlo_module =
      ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
fused_computation_1 {
p1.1 = f32[6400]{0} parameter(1)
mul = f32[6400]{0} multiply(p1.1, p1.1)
const.1 = f32[] parameter(0)
ROOT reduce.1 = f32[] reduce(mul, const.1), dimensions={0}, to_apply=scalar_add_computation
}

fused_computation_2 {
p1.2 = f32[6400]{0} parameter(1)
r1 = f32[64,100]{0,1} reshape(p1.2)
const.2 = f32[] parameter(0)
ROOT reduce.2 = f32[] reduce(r1, const.2), dimensions={1,0}, to_apply=scalar_mul_computation
}

ENTRY entry {
p0 = f32[] parameter(0)
p1 = f32[6400]{0} parameter(1)
fusion.1 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_1
fusion.2 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_2
ROOT root = (f32[], f32[]) tuple(fusion.1, fusion.2)
})"))
          .ValueOrDie();
  const bool changed =
      GpuMultiOutputFusion().Run(hlo_module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
115
// Sibling reduce fusions whose outputs have different shapes (f32[] vs.
// f32[10]) must not be multi-output fused.
TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceOutputShapes) {
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
fused_computation_1 {
p1.1 = f32[10,10]{1,0} parameter(1)
mul = f32[10,10]{1,0} multiply(p1.1, p1.1)
const.1 = f32[] parameter(0)
ROOT reduce.1 = f32[] reduce(mul, const.1), dimensions={0,1}, to_apply=scalar_add_computation
}

fused_computation_2 {
p1.2 = f32[10,10]{1,0} parameter(1)
const.2 = f32[] parameter(0)
ROOT reduce.2 = f32[10]{0} reduce(p1.2, const.2), dimensions={0}, to_apply=scalar_mul_computation
}

ENTRY entry {
p0 = f32[] parameter(0)
p1.3 = f32[10,10]{1,0} parameter(1)
fusion.1 = f32[] fusion(p0, p1.3), kind=kInput, calls=fused_computation_1
p2 = f32[] parameter(2)
fusion.2 = f32[10]{0} fusion(p2, p1.3), kind=kInput, calls=fused_computation_2
ROOT root = (f32[], f32[10]{0}) tuple(fusion.1, fusion.2)
})"))
                    .ValueOrDie();
  // Incompatible reduce output shapes: the pass must report no change.
  ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}
142
// Two sibling fusions with reduce instruction roots sharing the same input
// param: they should be merged into one multi-output fusion emitting both
// reduces.
TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceFusions) {
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
fused_computation_1 {
p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
const.1 = f32[] parameter(0)
ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
}

fused_computation_2 {
p1.2 = f32[128,512,28,28]{3,2,1,0} parameter(1)
const.2 = f32[] parameter(0)
ROOT reduce.2 = f32[512]{0} reduce(p1.2, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation
}

ENTRY entry {
p0 = f32[] parameter(0)
p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
fusion.1 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_1
fusion.2 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_2
ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion.1, fusion.2)
})"))
                    .ValueOrDie();
  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
  SCOPED_TRACE(module->ToString());
  // Reach the merged fusion through a get-tuple-element of the entry root.
  const HloInstruction* fusion =
      module->entry_computation()->root_instruction()->operand(0)->operand(0);
  ASSERT_TRUE(fusion->IsMultiOutputFusion());
  EXPECT_THAT(fusion->fused_expression_root(),
              op::Tuple(op::Reduce(), op::Reduce()));
}
176
// Multi-output fusion with two reduce instructions root and a sibling reduce
// instruction sharing the same input param: the sibling should be folded in,
// yielding a three-output fusion.
TEST_F(MultiOutputFusionTest,
       MultiOutputFusionSiblingReduceAndReduceMultiOutputFusion) {
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
fused_computation (p0: f32[128,512,28,28]) -> (f32[512], f32[512]) {
const.1 = f32[] constant(1)
p0.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
mul = f32[128,512,28,28]{3,2,1,0} multiply(f32[128,512,28,28]{3,2,1,0} p0.1, f32[128,512,28,28]{3,2,1,0} p0.1)
reduce.1 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} mul, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
reduce.2 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} p0.1, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
ROOT tuple = (f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} reduce.1, f32[512]{0} reduce.2)
}

ENTRY entry (p0: f32[128,512,28,28]) -> (f32[512], f32[512], f32[512]) {
p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
const = f32[] constant(1)
fusion = (f32[512]{0}, f32[512]{0}) fusion(f32[128,512,28,28]{3,2,1,0} p0), kind=kInput, calls=fused_computation
get-tuple-element = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=0
get-tuple-element.1 = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=1
reduce.3 = f32[512]{0} reduce(p0, const), dimensions={0,2,3}, to_apply=scalar_add_computation
ROOT root = (f32[512]{0}, f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} get-tuple-element, f32[512]{0} get-tuple-element.1, f32[512]{0} reduce.3)
})"))
                    .ValueOrDie();
  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
  SCOPED_TRACE(module->ToString());
  const HloInstruction* fusion =
      module->entry_computation()->root_instruction()->operand(0)->operand(0);
  ASSERT_TRUE(fusion->IsMultiOutputFusion());
  // The pre-existing two-output fusion grew a third reduce output.
  EXPECT_THAT(fusion->fused_expression_root(),
              op::Tuple(op::Reduce(), op::Reduce(), op::Reduce()));
}
209
// Verify that if we already have a multi-output fusion that we prefer to pick
// a reduce op from its operands for checking shape compatibility.  Here the
// shapes are incompatible, so no fusion should happen.
TEST_F(MultiOutputFusionTest,
       MultiOutputFusionSiblingFusionCheckAgainstReduceOperand) {
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
fused_computation_1 {
p1.1 = f32[10,10]{1,0} parameter(1)
mul = f32[10,10]{1,0} multiply(p1.1, p1.1)
const.1 = f32[] parameter(0)
reduce.1 = f32[] reduce(p1.1, const.1), dimensions={0,1}, to_apply=scalar_add_computation
ROOT tuple = (f32[10,10], f32[]) tuple(mul, reduce.1)
}

fused_computation_2 {
p1.2 = f32[10,10]{1,0} parameter(1)
const.2 = f32[] parameter(0)
ROOT reduce.2 = f32[10] reduce(p1.2, const.2), dimensions={0}, to_apply=scalar_mul_computation
}

ENTRY entry {
p0 = f32[] parameter(0)
p1 = f32[10,10]{1,0} parameter(1)
p2 = f32[] parameter(2)
fusion.1 = (f32[10,10], f32[]) fusion(p0, p1), kind=kInput, calls=fused_computation_1
get-tuple-element.1 = f32[10,10] get-tuple-element((f32[10,10], f32[]) fusion.1), index=0
get-tuple-element.2 = f32[] get-tuple-element((f32[10,10], f32[]) fusion.1), index=1
fusion.2 = f32[10] fusion(p2, p1), kind=kInput, calls=fused_computation_2
ROOT root = (f32[10,10], f32[], f32[10]) tuple(get-tuple-element.1, get-tuple-element.2, fusion.2)
})"))
                    .ValueOrDie();
  // reduce.1 (f32[]) vs. reduce.2 (f32[10]) are shape-incompatible: no change.
  ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}
242
// Two sibling kLoop fusions reading the same parameter get merged into a
// single multi-output loop fusion that produces both results.
TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) {
  auto hlo_module =
      ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
fused_computation_1 {
p0.1 = f32[6400]{0} parameter(0)
ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
}

fused_computation_2 {
p0.2 = f32[6400]{0} parameter(0)
const.2 = f32[] constant(1)
broadcast = f32[6400]{0} broadcast(const.2), dimensions={}
ROOT div = f32[6400]{0} divide(p0.2, broadcast)
}

ENTRY entry {
p0 = f32[6400]{0} parameter(0)
fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1
fusion.2 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_2
ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, fusion.2)
})"))
          .ValueOrDie();
  ASSERT_TRUE(GpuMultiOutputFusion().Run(hlo_module.get()).ValueOrDie());
  SCOPED_TRACE(hlo_module->ToString());
  const HloInstruction* entry_root =
      hlo_module->entry_computation()->root_instruction();
  // Both root operands are GTEs of the same merged fusion.
  const HloInstruction* multi_output_fusion =
      entry_root->operand(0)->operand(0);
  ASSERT_TRUE(multi_output_fusion->IsMultiOutputFusion());
  EXPECT_THAT(multi_output_fusion->fused_expression_root(),
              op::Tuple(op::Multiply(), op::Divide()));
}
272
// Fusing a reduce into a loop fusion would require changing the fusion kind
// (kLoop -> kInput).  That's not supported yet, so the pass must be a no-op.
TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopReduceToInputFusion) {
  auto hlo_module =
      ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
fused_computation_1 {
p0.1 = f32[6400]{0} parameter(0)
ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
}

ENTRY entry {
p0 = f32[6400]{0} parameter(0)
fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1
const.2 = f32[] constant(0)
reduce = f32[] reduce(p0, const.2), dimensions={0}, to_apply=scalar_add_computation
ROOT root = (f32[6400]{0}, f32[]) tuple(fusion.1, reduce)
})"))
          .ValueOrDie();
  const bool changed =
      GpuMultiOutputFusion().Run(hlo_module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
292
// A kLoop fusion and an unfused elementwise sibling over the same parameter
// are merged into one multi-output loop fusion.
TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopElementwise) {
  auto hlo_module =
      ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
fused_computation_1 {
p0.1 = f32[6400]{0} parameter(0)
ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
}

ENTRY entry {
p0 = f32[6400]{0} parameter(0)
fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1
const.2 = f32[] constant(1)
broadcast = f32[6400]{0} broadcast(const.2), dimensions={}
div = f32[6400]{0} divide(p0, broadcast)
ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, div)
})"))
          .ValueOrDie();
  ASSERT_TRUE(GpuMultiOutputFusion().Run(hlo_module.get()).ValueOrDie());
  SCOPED_TRACE(hlo_module->ToString());
  const HloInstruction* entry_root =
      hlo_module->entry_computation()->root_instruction();
  const HloInstruction* multi_output_fusion =
      entry_root->operand(0)->operand(0);
  ASSERT_TRUE(multi_output_fusion->IsMultiOutputFusion());
  EXPECT_THAT(multi_output_fusion->fused_expression_root(),
              op::Tuple(op::Multiply(), op::Divide()));
}
317
// Sibling loop fusions with different output shapes (elementwise result vs.
// reduced result) are incompatible; expect the pass to report no change.
TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopsDifferentShapes) {
  auto hlo_module =
      ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
fused_computation_1 {
p0.1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
ROOT mul = f32[8,1,5,16,1,2]{5,4,3,2,1,0} multiply(p0.1, p0.1)
}

fused_computation_2 {
p0.2 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
const.2 = f32[] constant(0)
ROOT reduce = f32[1,5,1,2]{3,2,1,0} reduce(p0.2, const.2), dimensions={0,3}, to_apply=scalar_add_computation
}

ENTRY entry {
p0 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
fusion.1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_1
fusion.2 = f32[1,5,1,2]{3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2
ROOT root = (f32[8,1,5,16,1,2]{5,4,3,2,1,0}, f32[1,5,1,2]{3,2,1,0}) tuple(fusion.1, fusion.2)
})"))
          .ValueOrDie();
  const bool changed =
      GpuMultiOutputFusion().Run(hlo_module.get()).ValueOrDie();
  ASSERT_FALSE(changed);
}
340
// A sibling loop fusion is merged into an existing multi-output loop fusion
// with the same output shape, growing it from two outputs to three.
TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoop) {
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
fused_computation_1 {
p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1)
ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0},
f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp)
}

fused_computation_2 {
p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
const.2 = f32[] constant(0)
broadcast = f32[8,1,5,16,1,1]{5,4,3,2,1,0} broadcast(const.2),
dimensions={}
ROOT add = f32[8,1,5,16,1,1]{5,4,3,2,1,0} add(p0.2, broadcast)
}

ENTRY entry {
p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0},
f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop,
calls=fused_computation_1
fusion.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop,
calls=fused_computation_2
gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0
gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1
ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0},
f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0})
tuple(gte0, gte1, fusion.2)
})"))
                    .ValueOrDie();
  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
  SCOPED_TRACE(module->ToString());
  const HloInstruction* fusion =
      module->entry_computation()->root_instruction()->operand(0)->operand(0);
  ASSERT_TRUE(fusion->IsMultiOutputFusion());
  // fusion.2's add becomes the third output of the merged fusion.
  EXPECT_THAT(fusion->fused_expression_root(),
              op::Tuple(op::Multiply(), op::Exp(), op::Add()));
}
381
// Same setup as above, but the sibling is a reduce with a different output
// shape than the multi-output loop fusion; no fusion should happen.
TEST_F(MultiOutputFusionTest,
       MultiOutputFusionSiblingLoopAndMultiOutputLoopDifferentShapes) {
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
fused_computation_1 {
p0.1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
mul = f32[8,1,5,16,1,2]{5,4,3,2,1,0} multiply(p0.1, p0.1)
exp = f32[8,1,5,16,1,2]{5,4,3,2,1,0} exponential(p0.1)
ROOT tuple = (f32[8,1,5,16,1,2]{5,4,3,2,1,0},
f32[8,1,5,16,1,2]{5,4,3,2,1,0}) tuple(mul, exp)
}

fused_computation_2 {
p0.2 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
const.2 = f32[] constant(0)
ROOT reduce = f32[1,5,1,2]{3,2,1,0} reduce(p0.2, const.2),
dimensions={0,3}, to_apply=scalar_add_computation
}

ENTRY entry {
p0 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} parameter(0)
fusion.1 = (f32[8,1,5,16,1,2]{5,4,3,2,1,0},
f32[8,1,5,16,1,2]{5,4,3,2,1,0}) fusion(p0), kind=kLoop,
calls=fused_computation_1
fusion.2 = f32[1,5,1,2]{3,2,1,0} fusion(p0), kind=kLoop,
calls=fused_computation_2
gte0 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0
gte1 = f32[8,1,5,16,1,2]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1
ROOT root = (f32[8,1,5,16,1,2]{5,4,3,2,1,0},
f32[8,1,5,16,1,2]{5,4,3,2,1,0}, f32[1,5,1,2]{3,2,1,0})
tuple(gte0, gte1, fusion.2)
})"))
                    .ValueOrDie();
  ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}
416
// Producer-consumer fusion: an elementwise producer (exp) that is consumed
// both by a reduce and by the entry root gets multi-output fused with the
// reduce, so exp is computed once.
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) {
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
ENTRY reduce {
p0 = f32[32,32,32]{2,1,0} parameter(0)
c0 = f32[] constant(0)
exp = f32[32,32,32]{2,1,0} exponential(p0)
reduce = f32[32,32]{1,0} reduce(exp, c0), dimensions={2},
to_apply=scalar_add_computation
ROOT root = (f32[32,32]{1,0}, f32[32,32,32]{2,1,0}) tuple(reduce, exp)
})"))
                    .ValueOrDie();
  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
  SCOPED_TRACE(module->ToString());
  // Both entry-root operands must now come out of the same fusion via GTEs.
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement()));
  const HloInstruction* fusion = root->operand(0)->operand(0);
  ASSERT_TRUE(fusion->IsMultiOutputFusion());
  EXPECT_THAT(fusion->fused_expression_root(),
              op::Tuple(op::Reduce(), op::Exp()));
}
437
// Producer-consumer fusion where the producer is itself a loop fusion (add):
// the reduce consumer is multi-output fused with it.
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
fused_add {
p0.1 = f32[32,32,32]{2,1,0} parameter(0)
p1.1 = f32[32,32,32]{2,1,0} parameter(1)
ROOT add = f32[32,32,32]{2,1,0} add(p0.1, p1.1)
}

ENTRY reduce {
p0 = f32[32,32,32]{2,1,0} parameter(0)
p1 = f32[32,32,32]{2,1,0} parameter(1)
c0 = f32[] constant(0)
add = f32[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_add
reduce = f32[32,32]{1,0} reduce(add, c0), dimensions={2},
to_apply=scalar_add_computation
ROOT root = (f32[32,32]{1,0}, f32[32,32,32]{2,1,0}) tuple(reduce, add)
})"))
                    .ValueOrDie();
  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
  SCOPED_TRACE(module->ToString());
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement()));
  const HloInstruction* fusion = root->operand(0)->operand(0);
  ASSERT_TRUE(fusion->IsMultiOutputFusion());
  // The merged fusion emits both the reduce and the producer's add.
  EXPECT_THAT(fusion->fused_expression_root(),
              op::Tuple(op::Reduce(), op::Add()));
}
465
// Producer-consumer fusion of a loop-fusion producer (select) into a
// multi-output reduce fusion consumer: the result is a three-output fusion
// (two reduces plus the select).
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) {
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
fused_select {
p1.1 = f32[32,32,32]{2,1,0} parameter(1)
c0 = f32[] constant(0)
broadcast = f32[32,32,32]{2,1,0} broadcast(f32[] c0), dimensions={}
greater-than = pred[32,32,32]{2,1,0} compare(f32[32,32,32]{2,1,0} p1.1,
f32[32,32,32]{2,1,0} broadcast), direction=GT
p0.1 = f32[32,32,32]{2,1,0} parameter(0)
ROOT select = f32[32,32,32]{2,1,0} select(pred[32,32,32]{2,1,0}
greater-than, f32[32,32,32]{2,1,0} p0.1, f32[32,32,32]{2,1,0} broadcast)
}

fused_reduce {
p0.2 = f32[32,32,32]{2,1,0} parameter(0)
c1 = f32[] constant(0)
r1 = f32[32,32]{1,0} reduce(p0.2, c1), dimensions={2},
to_apply=scalar_add_computation
mul = f32[32,32,32]{2,1,0} multiply(p0.2, p0.2)
r2 = f32[32,32]{1,0} reduce(mul, c1), dimensions={2},
to_apply=scalar_add_computation
ROOT tuple = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(r1, r2)
}

ENTRY reduce {
p0 = f32[32,32,32]{2,1,0} parameter(0)
p1 = f32[32,32,32]{2,1,0} parameter(1)
select = f32[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select
fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(select), kind=kInput,
calls=fused_reduce
gte0 = f32[32,32]{1,0} get-tuple-element(fusion), index=0
gte1 = f32[32,32]{1,0} get-tuple-element(fusion), index=1
ROOT root = (f32[32,32]{1,0}, f32[32,32]{1,0}, f32[32,32,32]{2,1,0})
tuple(gte1, gte1, select)
})"))
                    .ValueOrDie();
  // NOTE(review): the entry root uses gte1 twice and leaves gte0 unused —
  // presumably intentional (duplicate GTE consumers), but worth confirming.
  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
  SCOPED_TRACE(module->ToString());
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement(),
                              op::GetTupleElement()));
  const HloInstruction* fusion = root->operand(0)->operand(0);
  ASSERT_TRUE(fusion->IsMultiOutputFusion());
  EXPECT_THAT(fusion->fused_expression_root(),
              op::Tuple(op::Reduce(), op::Reduce(), op::Select()));
}
512
// A kLoop consumer whose body contains a broadcast-then-reduce must not be
// producer-consumer fused with an elementwise loop-fusion producer; the
// pass should leave the module unchanged.
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) {
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
fused_element_wise {
p0.1 = f32[2,2,2]{2,1,0} parameter(0)
p1.1 = f32[2,2,2]{2,1,0} parameter(1)
ROOT root = f32[2,2,2]{2,1,0} add(p0.1, p1.1)
}

fused_reduce {
p0.2 = f32[2,2,2]{2,1,0} parameter(0)
mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2,
f32[2,2,2]{2,1,0} p0.2)
broadcast = f32[2,2,2,2]{3,2,1,0} broadcast(mul), dimensions={3,2,1}
c1 = f32[] constant(0)
ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2,2]{3,2,1,0} broadcast,
f32[] c1), dimensions={1,3}, to_apply=scalar_add_computation
}

ENTRY reduce {
p0 = f32[2,2,2]{2,1,0} parameter(0)
p1 = f32[2,2,2]{2,1,0} parameter(1)
element_wise = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_element_wise
fusion = f32[2,2]{1,0} fusion(element_wise), kind=kLoop, calls=fused_reduce
ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(fusion, element_wise)
})"))
                    .ValueOrDie();
  ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}
541
// Same as ProducerConsumerFusionLoopFusionAndReduceFusion but with an f16
// producer whose output is converted to f32 inside the reduce fusion; the
// mixed precision must not block the multi-output fusion.
TEST_F(MultiOutputFusionTest,
       ProducerConsumerFusionFp16LoopFusionAndReduceFusion) {
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
fused_select {
p1.1 = f16[32,32,32]{2,1,0} parameter(1)
c0 = f16[] constant(0)
broadcast = f16[32,32,32]{2,1,0} broadcast(f16[] c0), dimensions={}
greater-than = pred[32,32,32]{2,1,0} compare(f16[32,32,32]{2,1,0} p1.1,
f16[32,32,32]{2,1,0} broadcast), direction=GT
p0.1 = f16[32,32,32]{2,1,0} parameter(0)
ROOT select = f16[32,32,32]{2,1,0} select(pred[32,32,32]{2,1,0}
greater-than, f16[32,32,32]{2,1,0} p0.1, f16[32,32,32]{2,1,0} broadcast)
}
fused_reduce {
p0.2 = f16[32,32,32]{2,1,0} parameter(0)
convert = f32[32,32,32]{2,1,0} convert(p0.2)
c1 = f32[] constant(0)
r1 = f32[32,32]{1,0} reduce(convert, c1), dimensions={2},
to_apply=scalar_add_computation
mul = f32[32,32,32]{2,1,0} multiply(convert, convert)
r2 = f32[32,32]{1,0} reduce(mul, c1), dimensions={2},
to_apply=scalar_add_computation
ROOT tuple = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(r1, r2)
}
ENTRY reduce {
p0 = f16[32,32,32]{2,1,0} parameter(0)
p1 = f16[32,32,32]{2,1,0} parameter(1)
select = f16[32,32,32]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select
fusion = (f32[32,32]{1,0}, f32[32,32]{1,0}) fusion(select), kind=kInput,
calls=fused_reduce
gte0 = f32[32,32]{1,0} get-tuple-element(fusion), index=0
gte1 = f32[32,32]{1,0} get-tuple-element(fusion), index=1
ROOT root = (f32[32,32]{1,0}, f32[32,32]{1,0}, f16[32,32,32]{2,1,0})
tuple(gte1, gte1, select)
})"))
                    .ValueOrDie();
  // NOTE(review): as in the f32 variant, the root reuses gte1 and leaves
  // gte0 unused — assumed intentional; confirm if this test is revised.
  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
  SCOPED_TRACE(module->ToString());
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement(),
                              op::GetTupleElement()));
  const HloInstruction* fusion = root->operand(0)->operand(0);
  ASSERT_TRUE(fusion->IsMultiOutputFusion());
  EXPECT_THAT(fusion->fused_expression_root(),
              op::Tuple(op::Reduce(), op::Reduce(), op::Select()));
}
588
// A producer whose inputs have mixed layouts (it contains a layout-changing
// copy) is reduce-unfriendly; it must not be fused into the reduce fusion.
TEST_F(MultiOutputFusionTest,
       ProducerConsumerFusionReduceUnfriendlyLoopFusion) {
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
mixed_input_layouts_computation {
p0.1 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1)
copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1)
c0 = f16[] constant(0)
broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={}
greater-than = pred[128,1024,32,32]{1,3,2,0} compare(copy, broadcast), direction=GT
ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast)
}
fused_reduce {
p0.2 = f16[128,1024,32,32]{1,3,2,0} parameter(0)
convert = f32[128,1024,32,32]{1,3,2,0} convert(p0.2)
c0.2 = f32[] constant(0)
ROOT reduce = f32[1024]{0} reduce(convert, c0.2), dimensions={0,2,3}, to_apply=scalar_add_computation
}
ENTRY reduce {
p0 = f16[128,1024,32,32]{3,2,1,0} parameter(0)
p1 = f16[128,1024,32,32]{1,3,2,0} parameter(1)
loop_fusion = f16[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=mixed_input_layouts_computation
reduce_fusion = f32[1024]{0} fusion(loop_fusion), kind=kInput, calls=fused_reduce
ROOT root = (f32[1024]{0}, f16[128,1024,32,32]{1,3,2,0}) tuple(reduce_fusion, loop_fusion)
})"))
                    .ValueOrDie();
  ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}
617
// `add` and `mul` can each be multi-output fused with one of the reduce
// fusions, but doing both would create a cycle in the graph; verify that
// exactly one multi-output fusion is formed.
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionAvoidsCycles) {
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
fused_add {
p0 = f32[32,32,32]{2,1,0} parameter(0)
p1 = f32[32,32,32]{2,1,0} parameter(1)
ROOT add = f32[32,32,32]{2,1,0} add(p0, p1)
}

fused_mul {
p2 = f32[64,64,64]{2,1,0} parameter(0)
p3 = f32[64,64,64]{2,1,0} parameter(1)
ROOT multiply = f32[64,64,64]{2,1,0} multiply(p2, p3)
}

fused_reduce_1 {
p4 = f32[32,32,32]{2,1,0} parameter(0)
p5 = f32[64,64,64]{2,1,0} parameter(1)
slice = f32[32,32,32]{2,1,0} slice(p5), slice={[0:32], [0:32], [0:32]}
add = f32[32,32,32]{2,1,0} add(p4, slice)
c0 = f32[] constant(0)
ROOT r1 = f32[32,32]{1,0} reduce(add, c0), dimensions={2},
to_apply=scalar_add_computation
}

fused_reduce_2 {
p6 = f32[32,32,32]{2,1,0} parameter(0)
p7 = f32[64,64,64]{2,1,0} parameter(1)
c0 = f32[] constant(0)
pad = f32[64,64,64]{2,1,0} pad(p6, c0), padding=16_16x16_16x16_16
mul = f32[64,64,64]{2,1,0} multiply(pad, p7)
ROOT r1 = f32[64,64]{1,0} reduce(mul, c0), dimensions={2},
to_apply=scalar_add_computation
}

ENTRY reduce {
p8 = f32[32,32,32]{2,1,0} parameter(0)
p9 = f32[64,64,64]{2,1,0} parameter(1)
// `add` and `mul` can be multi-output fused with `reduce1` and `reduce2`,
// respectively. However, both isn't possible, because multi-output fusion
// will introduce an extra dependency from `add` to `mul` or vice versa.
// Hence, the second multi-output fusion would introduce a cycle.
add = f32[32,32,32]{2,1,0} fusion(p8, p8), kind=kLoop, calls=fused_add
mul = f32[64,64,64]{2,1,0} fusion(p9, p9), kind=kLoop, calls=fused_mul

reduce1 = f32[32,32]{1,0} fusion(add, mul), kind=kInput,
calls=fused_reduce_1
reduce2 = f32[64,64]{1,0} fusion(add, mul), kind=kInput,
calls=fused_reduce_2
ROOT root = (f32[32,32,32]{2,1,0}, f32[32,32]{1,0}, f32[64,64]{1,0},
f32[64,64,64]{2,1,0}) tuple(add, reduce1, reduce2, mul)
})"))
                    .ValueOrDie();
  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
  SCOPED_TRACE(module->ToString());
  // Only one of the two candidate multi-output fusions may be applied.
  EXPECT_EQ(1, CountMultiOutputFusions(module.get()));
}
674
// When a producer (`add`) could be multi-output fused with either an unfused
// reduce consumer or a fusion consumer, prefer the fusion consumer; exactly
// one multi-output fusion should result.
TEST_F(MultiOutputFusionTest, PreferFuseProducerIntoFusionConsumer) {
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
fused_add {
p0 = f32[32,32,32]{2,1,0} parameter(0)
p1 = f32[32,32,32]{2,1,0} parameter(1)
ROOT add = f32[32,32,32]{2,1,0} add(p0, p1)
}
fused_reduce {
p0 = f32[32,32,32]{2,1,0} parameter(0)
p1 = f32[64,64,64]{2,1,0} parameter(1)
slice = f32[32,32,32]{2,1,0} slice(p1), slice={[0:32], [0:32], [0:32]}
add = f32[32,32,32]{2,1,0} add(p0, slice)
c0 = f32[] constant(0)
ROOT r1 = f32[32,32]{1,0} reduce(add, c0), dimensions={2},
to_apply=scalar_add_computation
}
ENTRY reduce {
p0 = f32[32,32,32]{2,1,0} parameter(0)
p1 = f32[64,64,64]{2,1,0} parameter(1)
add = f32[32,32,32]{2,1,0} fusion(p0, p0), kind=kLoop, calls=fused_add
c0 = f32[] constant(0)
reduce2 = f32[32,32]{1,0} reduce(add, c0), dimensions={2},
to_apply=scalar_add_computation
reduce = f32[32,32]{1,0} fusion(add, p1), kind=kInput, calls=fused_reduce
ROOT root = (f32[32,32,32]{2,1,0}, f32[32,32]{1,0}, f32[32,32]{1,0})
tuple(add, reduce, reduce2)
})"))
                    .ValueOrDie();
  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
  SCOPED_TRACE(module->ToString());
  // Use the file-level helper instead of duplicating the counting loop;
  // keeps this test consistent with ProducerConsumerFusionAvoidsCycles.
  EXPECT_EQ(1, CountMultiOutputFusions(module.get()));
}
715
716 // Check that we limit the number of operands to fusions we create.
TEST_F(MultiOutputFusionTest,AvoidsLargeFusion)717 TEST_F(MultiOutputFusionTest, AvoidsLargeFusion) {
718 constexpr int64_t kNumParams = 200;
719 ASSERT_GT(kNumParams, MaxOperandsAndOutputsPerFusion());
720
721 // Compute
722 // p0 * p1,
723 // p0 * p1 + p1 * p2
724 // p0 * p1 + p1 * p2 + p2 * p3
725 // ...
726 // where each of the (pi * pj)'s is represented as a fusion node so that
727 // multi-output fusion will pay attention to it.
728 auto module = CreateNewVerifiedModule();
729 HloComputation::Builder b(TestName());
730 Shape shape = ShapeUtil::MakeShape(F32, {10, 100});
731
732 std::vector<HloInstruction*> params;
733 for (int64_t i = 0; i < kNumParams; ++i) {
734 params.push_back(
735 b.AddInstruction(HloInstruction::CreateParameter(i, shape, "p")));
736 }
737
738 // Creates a fusion node that calculates x*y.
739 auto make_fusion = [&](HloInstruction* x, HloInstruction* y) {
740 HloComputation::Builder sub_builder("subcomp");
741 auto* p0 = sub_builder.AddInstruction(
742 HloInstruction::CreateParameter(0, shape, "p"));
743 auto* p1 = sub_builder.AddInstruction(
744 HloInstruction::CreateParameter(1, shape, "p"));
745 sub_builder.AddInstruction(
746 HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1));
747 HloComputation* subcomp =
748 module->AddEmbeddedComputation(sub_builder.Build());
749 return HloInstruction::CreateFusion(
750 shape, HloInstruction::FusionKind::kLoop, {x, y}, subcomp);
751 };
752
753 auto* sum = b.AddInstruction(make_fusion(params[0], params[1]));
754 for (int64_t i = 2; i < kNumParams; ++i) {
755 sum = b.AddInstruction(HloInstruction::CreateBinary(
756 shape, HloOpcode::kAdd, sum,
757 b.AddInstruction(make_fusion(params[i - 1], params[i]))));
758 }
759 auto computation = module->AddEntryComputation(b.Build());
760 EXPECT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
761 SCOPED_TRACE(module->ToString());
762 for (const HloInstruction* instr : computation->instructions()) {
763 EXPECT_LE(instr->operand_count() + ShapeUtil::SubshapeCount(instr->shape()),
764 MaxOperandsAndOutputsPerFusion())
765 << instr->ToString();
766 }
767 }
768
TEST_F(MultiOutputFusionTest, MultiOutputFusionDUS) {
  // Two sibling loop fusions whose roots are dynamic-update-slice ops that
  // share the same update operand p.1. NOTE(review): the pass is expected to
  // make no change here -- presumably multi-output fusing sibling DUS
  // fusions is disallowed (e.g. to keep the in-place DUS optimization
  // viable); confirm against the pass's DUS handling.
  auto module = ParseAndReturnVerifiedModule(R"(HloModule dus_mof
  fusion.1 {
    p.0 = f16[50,96,1024]{2,1,0} parameter(0)
    p.1 = f16[1,96,1024]{2,1,0} parameter(1)
    c.0 = s32[3]{0} constant({0, 0, 0})
    ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.1, c.0)
  }

  fusion.2 {
    p.0 = f16[50,96,1024]{2,1,0} parameter(0)
    p.1 = f16[1,96,1024]{2,1,0} parameter(1)
    c.0 = s32[3]{0} constant({0, 0, 0})
    ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.1, c.0)
  }

  ENTRY entry {
    p.00 = f16[50,96,1024]{2,1,0} parameter(0)
    p.01 = f16[50,96,1024]{2,1,0} parameter(1)
    p.1 = f16[1,96,1024]{2,1,0} parameter(2)

    f1 = f16[50,96,1024] fusion(p.00, p.1), kind=kLoop, calls=fusion.1
    f2 = f16[50,96,1024] fusion(p.01, p.1), kind=kLoop, calls=fusion.2
    ROOT tuple = (f16[50,96,1024],f16[50,96,1024]) tuple(f1, f2)
  })")
                  .ValueOrDie();
  // Run() returning false means the pass changed nothing.
  ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}
797
// Check that we don't fuse too many reductions together.
TEST_F(MultiOutputFusionTest, SharedMemoryBudget) {
  // Ten structurally identical add+reduce input fusions (f32[64,64] operands,
  // reduced along dimensions={0}); adjacent fusions share one parameter so
  // every pair is a sibling-fusion candidate. NOTE(review): the expected
  // count of 3 multi-output fusions below presumably reflects the pass's
  // shared-memory budget for grouped reductions -- revisit if that budget
  // changes.
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
    fused_computation0 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
        to_apply=scalar_add_computation
    }
    fused_computation1 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
        to_apply=scalar_add_computation
    }
    fused_computation2 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
        to_apply=scalar_add_computation
    }
    fused_computation3 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
        to_apply=scalar_add_computation
    }
    fused_computation4 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
        to_apply=scalar_add_computation
    }
    fused_computation5 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
        to_apply=scalar_add_computation
    }
    fused_computation6 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
        to_apply=scalar_add_computation
    }
    fused_computation7 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
        to_apply=scalar_add_computation
    }
    fused_computation8 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
        to_apply=scalar_add_computation
    }
    fused_computation9 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={0},
        to_apply=scalar_add_computation
    }
    ENTRY computation {
      zero = f32[] constant(0)
      param0 = f32[64,64] parameter(0)
      param1 = f32[64,64] parameter(1)
      param2 = f32[64,64] parameter(2)
      param3 = f32[64,64] parameter(3)
      param4 = f32[64,64] parameter(4)
      param5 = f32[64,64] parameter(5)
      param6 = f32[64,64] parameter(6)
      param7 = f32[64,64] parameter(7)
      param8 = f32[64,64] parameter(8)
      param9 = f32[64,64] parameter(9)
      out0 = f32[64] fusion(param0, param1, zero), kind=kInput, calls=fused_computation0
      out1 = f32[64] fusion(param1, param2, zero), kind=kInput, calls=fused_computation1
      out2 = f32[64] fusion(param2, param3, zero), kind=kInput, calls=fused_computation2
      out3 = f32[64] fusion(param3, param4, zero), kind=kInput, calls=fused_computation3
      out4 = f32[64] fusion(param4, param5, zero), kind=kInput, calls=fused_computation4
      out5 = f32[64] fusion(param5, param6, zero), kind=kInput, calls=fused_computation5
      out6 = f32[64] fusion(param6, param7, zero), kind=kInput, calls=fused_computation6
      out7 = f32[64] fusion(param7, param8, zero), kind=kInput, calls=fused_computation7
      out8 = f32[64] fusion(param8, param9, zero), kind=kInput, calls=fused_computation8
      out9 = f32[64] fusion(param9, param0, zero), kind=kInput, calls=fused_computation9
      ROOT out = (f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64]) tuple(f32[64] out0, f32[64] out1, f32[64] out2, f32[64] out3, f32[64] out4, f32[64] out5, f32[64] out6, f32[64] out7, f32[64] out8, f32[64] out9)
    }
  )"))
                    .ValueOrDie();
  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).value());

  // The ten reductions must be grouped, but only into three fusions.
  EXPECT_EQ(3, CountMultiOutputFusions(module.get()));
}
911
TEST_F(MultiOutputFusionTest, DoNotGroupTooManyReductions) {
  // Same chain of ten sibling add+reduce fusions as SharedMemoryBudget, but
  // reducing along dimensions={1} (row reductions) instead of {0}.
  // NOTE(review): the expected count of 2 multi-output fusions below
  // presumably reflects a stricter grouping limit for this reduction
  // orientation -- revisit if the pass's grouping heuristic changes.
  auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
    fused_computation0 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
        to_apply=scalar_add_computation
    }
    fused_computation1 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
        to_apply=scalar_add_computation
    }
    fused_computation2 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
        to_apply=scalar_add_computation
    }
    fused_computation3 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
        to_apply=scalar_add_computation
    }
    fused_computation4 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
        to_apply=scalar_add_computation
    }
    fused_computation5 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
        to_apply=scalar_add_computation
    }
    fused_computation6 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
        to_apply=scalar_add_computation
    }
    fused_computation7 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
        to_apply=scalar_add_computation
    }
    fused_computation8 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
        to_apply=scalar_add_computation
    }
    fused_computation9 {
      p0 = f32[64,64] parameter(0)
      p1 = f32[64,64] parameter(1)
      p2 = f32[] parameter(2)
      add = f32[64,64] add(p0, p1)
      ROOT reduce = f32[64] reduce(f32[64,64] add, f32[] p2), dimensions={1},
        to_apply=scalar_add_computation
    }
    ENTRY computation {
      zero = f32[] constant(0)
      param0 = f32[64,64] parameter(0)
      param1 = f32[64,64] parameter(1)
      param2 = f32[64,64] parameter(2)
      param3 = f32[64,64] parameter(3)
      param4 = f32[64,64] parameter(4)
      param5 = f32[64,64] parameter(5)
      param6 = f32[64,64] parameter(6)
      param7 = f32[64,64] parameter(7)
      param8 = f32[64,64] parameter(8)
      param9 = f32[64,64] parameter(9)
      out0 = f32[64] fusion(param0, param1, zero), kind=kInput, calls=fused_computation0
      out1 = f32[64] fusion(param1, param2, zero), kind=kInput, calls=fused_computation1
      out2 = f32[64] fusion(param2, param3, zero), kind=kInput, calls=fused_computation2
      out3 = f32[64] fusion(param3, param4, zero), kind=kInput, calls=fused_computation3
      out4 = f32[64] fusion(param4, param5, zero), kind=kInput, calls=fused_computation4
      out5 = f32[64] fusion(param5, param6, zero), kind=kInput, calls=fused_computation5
      out6 = f32[64] fusion(param6, param7, zero), kind=kInput, calls=fused_computation6
      out7 = f32[64] fusion(param7, param8, zero), kind=kInput, calls=fused_computation7
      out8 = f32[64] fusion(param8, param9, zero), kind=kInput, calls=fused_computation8
      out9 = f32[64] fusion(param9, param0, zero), kind=kInput, calls=fused_computation9
      ROOT out = (f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64], f32[64]) tuple(f32[64] out0, f32[64] out1, f32[64] out2, f32[64] out3, f32[64] out4, f32[64] out5, f32[64] out6, f32[64] out7, f32[64] out8, f32[64] out9)
    }
  )"))
                    .ValueOrDie();
  ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).value());

  // Reductions are grouped, but into no more than two fusions.
  EXPECT_EQ(2, CountMultiOutputFusions(module.get()));
}
1024
TEST_F(MultiOutputFusionTest, NoFusionToAvoidUsingTooMuchSharedMemory) {
  // fusion.1 is already a multi-output fusion containing two f64 reductions;
  // fusion.2 reduces the same operands. NOTE(review): the pass is expected
  // to decline merging fusion.2 into fusion.1 -- per the test name,
  // presumably because the combined kernel's f64 reduction buffers would
  // exceed the shared-memory budget; confirm against the pass's
  // shared-memory check.
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule xla_computation_update_step.10931

%scalar_add_computation.1 (scalar_lhs.1: f64[], scalar_rhs.1: f64[]) -> f64[] {
  %scalar_lhs.1 = f64[] parameter(0)
  %scalar_rhs.1 = f64[] parameter(1)
  ROOT %add.1257 = f64[] add(f64[] %scalar_lhs.1, f64[] %scalar_rhs.1)
}

%fused_computation.1 (param_0.8: f64[64,64], param_1.11: f64[64,64], param_2.9: f64[64,64]) -> (f64[64], f64[64]) {
  %param_0.8 = f64[64,64]{1,0} parameter(0)
  %param_1.11 = f64[64,64]{1,0} parameter(1)
  %multiply.2 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %param_0.8, f64[64,64]{1,0} %param_1.11)
  %constant_5217.3 = f64[] constant(0)
  %broadcast.1 = f64[64,64]{1,0} broadcast(f64[] %constant_5217.3), dimensions={}
  %multiply.0 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %multiply.2, f64[64,64]{1,0} %broadcast.1)
  %reduce.0 = f64[64]{0} reduce(f64[64,64]{1,0} %multiply.0, f64[] %constant_5217.3), dimensions={0}, to_apply=%scalar_add_computation.1
  %param_2.9 = f64[64,64]{1,0} parameter(2)
  %multiply.1514.clone.0.clone.1 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %param_2.9, f64[64,64]{1,0} %param_1.11)
  %constant_5217.1.clone.1 = f64[] constant(0)
  %broadcast.0.clone.1 = f64[64,64]{1,0} broadcast(f64[] %constant_5217.1.clone.1), dimensions={}
  %multiply.1341.clone.0.clone.1 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %multiply.1514.clone.0.clone.1, f64[64,64]{1,0} %broadcast.0.clone.1)
  %reduce.630.clone.0.clone.1 = f64[64]{0} reduce(f64[64,64]{1,0} %multiply.1341.clone.0.clone.1, f64[] %constant_5217.1.clone.1), dimensions={0}, to_apply=%scalar_add_computation.1
  ROOT %tuple = (f64[64]{0}, f64[64]{0}) tuple(f64[64]{0} %reduce.0, f64[64]{0} %reduce.630.clone.0.clone.1)
}

%primitive_computation_add__1.6426 (parameter.6427: f64[], parameter.6428: f64[]) -> f64[] {
  %parameter.6427 = f64[] parameter(0)
  %parameter.6428 = f64[] parameter(1)
  ROOT %add.6429 = f64[] add(f64[] %parameter.6427, f64[] %parameter.6428)
}

%fused_computation.2 (param_0.7: f64[64,64], param_1.9: f64[64,64]) -> f64[64] {
  %param_0.7 = f64[64,64]{1,0} parameter(0)
  %param_1.9 = f64[64,64]{1,0} parameter(1)
  %multiply.1 = f64[64,64]{1,0} multiply(f64[64,64]{1,0} %param_0.7, f64[64,64]{1,0} %param_1.9)
  %constant_5217.2 = f64[] constant(0)
  ROOT %reduce.740.clone.0 = f64[64]{0} reduce(f64[64,64]{1,0} %multiply.1, f64[] %constant_5217.2), dimensions={0}, to_apply=%primitive_computation_add__1.6426
}

ENTRY %reproducer (param_0.1090: f64[64,64], param_1.1377: f64[64,64], param_2.1948: f64[64,64]) -> (f64[64], f64[64], f64[64]) {
  %param_0.1090 = f64[64,64]{1,0} parameter(0)
  %param_1.1377 = f64[64,64]{1,0} parameter(1)
  %param_2.1948 = f64[64,64]{1,0} parameter(2)
  %fusion.1 = (f64[64]{0}, f64[64]{0}) fusion(f64[64,64]{1,0} %param_0.1090, f64[64,64]{1,0} %param_1.1377, f64[64,64]{1,0} %param_2.1948), kind=kInput, calls=%fused_computation.1
  %get-tuple-element = f64[64]{0} get-tuple-element((f64[64]{0}, f64[64]{0}) %fusion.1), index=0
  %fusion.2 = f64[64]{0} fusion(f64[64,64]{1,0} %param_0.1090, f64[64,64]{1,0} %param_1.1377), kind=kInput, calls=%fused_computation.2
  %get-tuple-element.1 = f64[64]{0} get-tuple-element((f64[64]{0}, f64[64]{0}) %fusion.1), index=1
  ROOT %tuple.428 = (f64[64]{0}, f64[64]{0}, f64[64]{0}) tuple(f64[64]{0} %get-tuple-element, f64[64]{0} %fusion.2, f64[64]{0} %get-tuple-element.1)
}
)")
                    .ValueOrDie();
  // Run() returning false means the pass changed nothing.
  EXPECT_FALSE(GpuMultiOutputFusion().Run(module.get()).value());
}
1080
TEST_F(MultiOutputFusionTest, NoFusionToAvoidCodeDuplication) {
  // fusion.1 (a long chain of slice/select/add/dynamic-update-slice steps)
  // is consumed both by fusion.2 and by the root tuple. NOTE(review): per
  // the test name, producer-consumer multi-output fusion is expected to be
  // rejected here, presumably because fusing would duplicate fusion.1's
  // expensive body -- confirm against the pass's duplication heuristic.
  auto module = ParseAndReturnVerifiedModule(R"(
HloModule module

and.reduce_sub_computation {
  x = pred[] parameter(0)
  y = pred[] parameter(1)
  ROOT and = pred[] and(x, y)
}

fused_computation.1 {
  param_4.658 = f32[2,20,256]{2,0,1} parameter(4)
  slice.1385 = f32[2,1,256]{2,0,1} slice(param_4.658), slice={[0:2], [11:12], [0:256]}
  constant.6847 = s32[] constant(0)
  broadcast.4823 = s32[3]{0} broadcast(constant.6847), dimensions={}
  param_9.415 = s32[3]{0} parameter(9)
  compare.700 = pred[3]{0} compare(broadcast.4823, param_9.415), direction=LE
  constant.6846 = pred[] constant(true)
  reduce.221 = pred[] reduce(compare.700, constant.6846), dimensions={0}, to_apply=and.reduce_sub_computation
  broadcast.2933 = pred[2,1,256]{2,0,1} broadcast(reduce.221), dimensions={}
  param_5.528 = f32[2,512]{1,0} parameter(5)
  slice.1384 = f32[2,256]{1,0} slice(param_5.528), slice={[0:2], [0:256]}
  bitcast.341 = f32[2,1,256]{2,0,1} bitcast(slice.1384)
  constant.5418 = f32[] constant(0)
  broadcast.3227 = f32[2,1,256]{2,0,1} broadcast(constant.5418), dimensions={}
  select.173 = f32[2,1,256]{2,0,1} select(broadcast.2933, bitcast.341, broadcast.3227)
  add.573 = f32[2,1,256]{2,0,1} add(slice.1385, select.173)
  param_0.299 = s32[] parameter(0)
  constant.5157 = s32[] constant(11)
  dynamic-update-slice.189 = f32[2,20,256]{2,0,1} dynamic-update-slice(param_4.658, add.573, param_0.299, constant.5157, param_0.299)
  slice.1383 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.189), slice={[0:2], [10:11], [0:256]}
  constant.6800 = s32[] constant(0)
  broadcast.4803 = s32[3]{0} broadcast(constant.6800), dimensions={}
  param_8.484 = s32[3]{0} parameter(8)
  compare.681 = pred[3]{0} compare(broadcast.4803, param_8.484), direction=LE
  constant.6798 = pred[] constant(true)
  reduce.203 = pred[] reduce(compare.681, constant.6798), dimensions={0}, to_apply=and.reduce_sub_computation
  broadcast.2932 = pred[2,1,256]{2,0,1} broadcast(reduce.203), dimensions={}
  param_3.1169 = f32[2,512]{1,0} parameter(3)
  slice.1382 = f32[2,256]{1,0} slice(param_3.1169), slice={[0:2], [0:256]}
  bitcast.340 = f32[2,1,256]{2,0,1} bitcast(slice.1382)
  select.172 = f32[2,1,256]{2,0,1} select(broadcast.2932, bitcast.340, broadcast.3227)
  add.572 = f32[2,1,256]{2,0,1} add(slice.1383, select.172)
  constant.5154 = s32[] constant(10)
  dynamic-update-slice.188 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.189, add.572, param_0.299, constant.5154, param_0.299)
  slice.1381 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.188), slice={[0:2], [9:10], [0:256]}
  constant.6794 = s32[] constant(0)
  broadcast.4801 = s32[3]{0} broadcast(constant.6794), dimensions={}
  param_7.478 = s32[3]{0} parameter(7)
  compare.679 = pred[3]{0} compare(broadcast.4801, param_7.478), direction=LE
  constant.6793 = pred[] constant(true)
  reduce.201 = pred[] reduce(compare.679, constant.6793), dimensions={0}, to_apply=and.reduce_sub_computation
  broadcast.2930 = pred[2,1,256]{2,0,1} broadcast(reduce.201), dimensions={}
  param_2.1685 = f32[2,512]{1,0} parameter(2)
  slice.1380 = f32[2,256]{1,0} slice(param_2.1685), slice={[0:2], [0:256]}
  bitcast.339 = f32[2,1,256]{2,0,1} bitcast(slice.1380)
  select.171 = f32[2,1,256]{2,0,1} select(broadcast.2930, bitcast.339, broadcast.3227)
  add.571 = f32[2,1,256]{2,0,1} add(slice.1381, select.171)
  constant.5153 = s32[] constant(9)
  dynamic-update-slice.187 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.188, add.571, param_0.299, constant.5153, param_0.299)
  slice.1379 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.187), slice={[0:2], [8:9], [0:256]}
  constant.6788 = s32[] constant(0)
  broadcast.4799 = s32[3]{0} broadcast(constant.6788), dimensions={}
  param_6.495 = s32[3]{0} parameter(6)
  compare.677 = pred[3]{0} compare(broadcast.4799, param_6.495), direction=LE
  constant.6786 = pred[] constant(true)
  reduce.199 = pred[] reduce(compare.677, constant.6786), dimensions={0}, to_apply=and.reduce_sub_computation
  broadcast.2929 = pred[2,1,256]{2,0,1} broadcast(reduce.199), dimensions={}
  param_1.1408 = f32[2,512]{1,0} parameter(1)
  slice.1378 = f32[2,256]{1,0} slice(param_1.1408), slice={[0:2], [0:256]}
  bitcast.338 = f32[2,1,256]{2,0,1} bitcast(slice.1378)
  select.170 = f32[2,1,256]{2,0,1} select(broadcast.2929, bitcast.338, broadcast.3227)
  add.570 = f32[2,1,256]{2,0,1} add(slice.1379, select.170)
  constant.5152 = s32[] constant(8)
  ROOT dynamic-update-slice.186 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.187, add.570, param_0.299, constant.5152, param_0.299)
}

fused_computation.2 {
  param_4.655 = f32[2,20,256]{2,0,1} parameter(4)
  slice.1369 = f32[2,1,256]{2,0,1} slice(param_4.655), slice={[0:2], [7:8], [0:256]}
  param_6.483 = pred[] parameter(6)
  broadcast.2927 = pred[2,1,256]{2,0,1} broadcast(param_6.483), dimensions={}
  param_5.525 = f32[2,512]{1,0} parameter(5)
  slice.1368 = f32[2,256]{1,0} slice(param_5.525), slice={[0:2], [0:256]}
  bitcast.333 = f32[2,1,256]{2,0,1} bitcast(slice.1368)
  constant.5415 = f32[] constant(0)
  broadcast.3225 = f32[2,1,256]{2,0,1} broadcast(constant.5415), dimensions={}
  select.161 = f32[2,1,256]{2,0,1} select(broadcast.2927, bitcast.333, broadcast.3225)
  add.549 = f32[2,1,256]{2,0,1} add(slice.1369, select.161)
  param_0.265 = s32[] parameter(0)
  constant.5151 = s32[] constant(7)
  dynamic-update-slice.185 = f32[2,20,256]{2,0,1} dynamic-update-slice(param_4.655, add.549, param_0.265, constant.5151, param_0.265)
  slice.1367 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.185), slice={[0:2], [6:7], [0:256]}
  constant.6782 = s32[] constant(0)
  broadcast.4797 = s32[3]{0} broadcast(constant.6782), dimensions={}
  param_9.391 = s32[3]{0} parameter(9)
  compare.675 = pred[3]{0} compare(broadcast.4797, param_9.391), direction=LE
  constant.6781 = pred[] constant(true)
  reduce.197 = pred[] reduce(compare.675, constant.6781), dimensions={0}, to_apply=and.reduce_sub_computation
  broadcast.2926 = pred[2,1,256]{2,0,1} broadcast(reduce.197), dimensions={}
  param_3.1167 = f32[2,512]{1,0} parameter(3)
  slice.1366 = f32[2,256]{1,0} slice(param_3.1167), slice={[0:2], [0:256]}
  bitcast.332 = f32[2,1,256]{2,0,1} bitcast(slice.1366)
  select.160 = f32[2,1,256]{2,0,1} select(broadcast.2926, bitcast.332, broadcast.3225)
  add.548 = f32[2,1,256]{2,0,1} add(slice.1367, select.160)
  constant.5150 = s32[] constant(6)
  dynamic-update-slice.184 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.185, add.548, param_0.265, constant.5150, param_0.265)
  slice.1365 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.184), slice={[0:2], [5:6], [0:256]}
  constant.6776 = s32[] constant(0)
  broadcast.4794 = s32[3]{0} broadcast(constant.6776), dimensions={}
  param_8.464 = s32[3]{0} parameter(8)
  compare.673 = pred[3]{0} compare(broadcast.4794, param_8.464), direction=LE
  constant.6775 = pred[] constant(true)
  reduce.195 = pred[] reduce(compare.673, constant.6775), dimensions={0}, to_apply=and.reduce_sub_computation
  broadcast.2925 = pred[2,1,256]{2,0,1} broadcast(reduce.195), dimensions={}
  param_2.1684 = f32[2,512]{1,0} parameter(2)
  slice.1364 = f32[2,256]{1,0} slice(param_2.1684), slice={[0:2], [0:256]}
  bitcast.331 = f32[2,1,256]{2,0,1} bitcast(slice.1364)
  select.159 = f32[2,1,256]{2,0,1} select(broadcast.2925, bitcast.331, broadcast.3225)
  add.547 = f32[2,1,256]{2,0,1} add(slice.1365, select.159)
  constant.5149 = s32[] constant(5)
  dynamic-update-slice.183 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.184, add.547, param_0.265, constant.5149, param_0.265)
  slice.1363 = f32[2,1,256]{2,0,1} slice(dynamic-update-slice.183), slice={[0:2], [4:5], [0:256]}
  constant.6770 = s32[] constant(0)
  broadcast.4792 = s32[3]{0} broadcast(constant.6770), dimensions={}
  param_7.458 = s32[3]{0} parameter(7)
  compare.671 = pred[3]{0} compare(broadcast.4792, param_7.458), direction=LE
  constant.6769 = pred[] constant(true)
  reduce.193 = pred[] reduce(compare.671, constant.6769), dimensions={0}, to_apply=and.reduce_sub_computation
  broadcast.2924 = pred[2,1,256]{2,0,1} broadcast(reduce.193), dimensions={}
  param_1.1405 = f32[2,512]{1,0} parameter(1)
  slice.1362 = f32[2,256]{1,0} slice(param_1.1405), slice={[0:2], [0:256]}
  bitcast.330 = f32[2,1,256]{2,0,1} bitcast(slice.1362)
  select.158 = f32[2,1,256]{2,0,1} select(broadcast.2924, bitcast.330, broadcast.3225)
  add.546 = f32[2,1,256]{2,0,1} add(slice.1363, select.158)
  constant.5148 = s32[] constant(4)
  ROOT dynamic-update-slice.182 = f32[2,20,256]{2,0,1} dynamic-update-slice(dynamic-update-slice.183, add.546, param_0.265, constant.5148, param_0.265)
}

ENTRY main {
  param_0.0 = s32[] parameter(0)
  param_1.0 = f32[2,512]{1,0} parameter(1)
  param_2.0 = f32[2,512]{1,0} parameter(2)
  param_3.0 = f32[2,512]{1,0} parameter(3)
  param_4.0 = f32[2,20,256]{2,1,0} parameter(4)
  param_5.0 = f32[2,512]{1,0} parameter(5)
  param_6.0 = s32[3]{0} parameter(6)
  param_7.0 = s32[3]{0} parameter(7)
  param_8.0 = s32[3]{0} parameter(8)
  param_9.0 = s32[3]{0} parameter(9)
  fusion.1 = f32[2,20,256]{2,0,1} fusion(param_0.0, param_1.0, param_2.0, param_3.0, param_4.0, param_5.0, param_6.0, param_7.0, param_8.0, param_9.0), kind=kLoop, calls=fused_computation.1
  param_10 = pred[] parameter(10)
  fusion.2 = f32[2,20,256]{2,0,1} fusion(param_0.0, param_1.0, param_2.0, param_3.0, fusion.1, param_5.0, param_10, param_7.0, param_8.0, param_9.0), kind=kLoop, calls=fused_computation.2
  ROOT root = (f32[2,20,256]{2,0,1}, f32[2,20,256]{2,0,1}) tuple(fusion.1, fusion.2)
}
)")
                    .ValueOrDie();
  // Run() returning false means the pass changed nothing.
  EXPECT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}
1240
1241 } // namespace gpu
1242 } // namespace xla
1243