1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
17
18 #include <algorithm>
19 #include <memory>
20 #include <set>
21
22 #include "absl/strings/str_cat.h"
23 #include "absl/types/span.h"
24 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
25 #include "tensorflow/compiler/xla/service/transpose_folding.h"
26 #include "tensorflow/compiler/xla/shape.h"
27 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
28 #include "tensorflow/compiler/xla/tests/test_utils.h"
29
30 namespace op = xla::testing::opcode_matchers;
31
32 namespace xla {
33 namespace cpu {
34 namespace {
35
36 using InstructionFusionTest = HloTestBase;
37
MakeDot(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs)38 std::unique_ptr<HloInstruction> MakeDot(const Shape& shape, HloInstruction* lhs,
39 HloInstruction* rhs) {
40 DotDimensionNumbers dot_dnums;
41 dot_dnums.add_lhs_contracting_dimensions(lhs->shape().rank() - 1);
42 dot_dnums.add_rhs_contracting_dimensions(0);
43 PrecisionConfig precision_config;
44 precision_config.mutable_operand_precision()->Resize(
45 2, PrecisionConfig::DEFAULT);
46 return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums,
47 precision_config);
48 }
49
TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) {
  // Matrix-vector dot whose LHS is exp(arg0); CPU fusion is expected to turn
  // the root into a fusion instruction.
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {1024, 256});
  const Shape vector_shape = ShapeUtil::MakeShape(F32, {256});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {1024});

  HloComputation::Builder builder(TestName());
  HloInstruction* const lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_shape, "arg0"));
  HloInstruction* const rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, vector_shape, "arg1"));
  HloInstruction* const exp_lhs = builder.AddInstruction(
      HloInstruction::CreateUnary(matrix_shape, HloOpcode::kExp, lhs));
  HloInstruction* const dot =
      builder.AddInstruction(MakeDot(result_shape, exp_lhs, rhs));

  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(entry->root_instruction(), dot);

  const bool changed = CpuInstructionFusion().Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_THAT(entry->root_instruction(), op::Fusion());
}
68
TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) {
  // Vector-matrix dot whose RHS is exp(arg1); CPU fusion is expected to turn
  // the root into a fusion instruction.
  const Shape vector_shape = ShapeUtil::MakeShape(F32, {256});
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {256, 1024});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {1024});

  HloComputation::Builder builder(TestName());
  HloInstruction* const lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape, "arg0"));
  HloInstruction* const rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, matrix_shape, "arg1"));
  HloInstruction* const exp_rhs = builder.AddInstruction(
      HloInstruction::CreateUnary(matrix_shape, HloOpcode::kExp, rhs));
  HloInstruction* const dot =
      builder.AddInstruction(MakeDot(result_shape, lhs, exp_rhs));

  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(entry->root_instruction(), dot);

  const bool changed = CpuInstructionFusion().Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_THAT(entry->root_instruction(), op::Fusion());
}
87
TEST_F(InstructionFusionTest, DotOperationFusion_Bitcast) {
  // exp(arg0) is bitcast from rank 4 to a [1024,256] matrix before feeding the
  // dot LHS; the chain is still expected to end up fused.
  const Shape rank4_shape = ShapeUtil::MakeShape(F32, {2, 512, 2, 128});
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {1024, 256});
  const Shape vector_shape = ShapeUtil::MakeShape(F32, {256});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {1024});

  HloComputation::Builder builder(TestName());
  HloInstruction* const lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, rank4_shape, "arg0"));
  HloInstruction* const rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, vector_shape, "arg1"));
  HloInstruction* const exp_lhs = builder.AddInstruction(
      HloInstruction::CreateUnary(rank4_shape, HloOpcode::kExp, lhs));
  HloInstruction* const bitcast = builder.AddInstruction(
      HloInstruction::CreateUnary(matrix_shape, HloOpcode::kBitcast, exp_lhs));
  HloInstruction* const dot =
      builder.AddInstruction(MakeDot(result_shape, bitcast, rhs));

  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(entry->root_instruction(), dot);

  const bool changed = CpuInstructionFusion().Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_THAT(entry->root_instruction(), op::Fusion());
}
108
TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) {
  // Same as the bitcast case but with an explicit reshape from rank 4 to a
  // [1024,256] matrix between exp(arg0) and the dot.
  const Shape rank4_shape = ShapeUtil::MakeShape(F32, {2, 512, 2, 128});
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {1024, 256});
  const Shape vector_shape = ShapeUtil::MakeShape(F32, {256});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {1024});

  HloComputation::Builder builder(TestName());
  HloInstruction* const lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, rank4_shape, "arg0"));
  HloInstruction* const rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, vector_shape, "arg1"));
  HloInstruction* const exp_lhs = builder.AddInstruction(
      HloInstruction::CreateUnary(rank4_shape, HloOpcode::kExp, lhs));
  HloInstruction* const reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(matrix_shape, exp_lhs));
  HloInstruction* const dot =
      builder.AddInstruction(MakeDot(result_shape, reshape, rhs));

  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(entry->root_instruction(), dot);

  const bool changed = CpuInstructionFusion().Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_THAT(entry->root_instruction(), op::Fusion());
}
130
TEST_F(InstructionFusionTest, DotOperationFusion_TooLarge) {
  // The exponential producing the RHS here is [32K, 256]; the test expects the
  // fusion pass to leave the graph unchanged with the dot still at the root.
  const int64_t rows = 32 * 1024;
  const Shape vector_shape = ShapeUtil::MakeShape(F32, {rows});
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {rows, 256});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {256});

  HloComputation::Builder builder(TestName());
  HloInstruction* const lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape, "arg0"));
  HloInstruction* const rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, matrix_shape, "arg1"));
  HloInstruction* const exp_rhs = builder.AddInstruction(
      HloInstruction::CreateUnary(matrix_shape, HloOpcode::kExp, rhs));
  HloInstruction* const dot =
      builder.AddInstruction(MakeDot(result_shape, lhs, exp_rhs));

  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(entry->root_instruction(), dot);

  const bool changed = CpuInstructionFusion().Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
  EXPECT_EQ(entry->root_instruction(), dot);
}
149
TEST_F(InstructionFusionTest, DotOperationFusion_ElementReuse) {
  // The LHS has two rows, so every element of exp(arg1) contributes to more
  // than one output element. The test expects no fusion to happen.
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 256});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {256, 1024});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {2, 1024});

  HloComputation::Builder builder(TestName());
  HloInstruction* const lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, lhs_shape, "arg0"));
  HloInstruction* const rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, rhs_shape, "arg1"));
  HloInstruction* const exp_rhs = builder.AddInstruction(
      HloInstruction::CreateUnary(rhs_shape, HloOpcode::kExp, rhs));
  HloInstruction* const dot =
      builder.AddInstruction(MakeDot(result_shape, lhs, exp_rhs));

  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(entry->root_instruction(), dot);

  const bool changed = CpuInstructionFusion().Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);
  EXPECT_EQ(entry->root_instruction(), dot);
}
168
TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion_RHS) {
  // TransposeFolding should absorb the transpose feeding the dot RHS into the
  // dot's dimension numbers: the RHS contracting dimension moves from 0 to 1.
  const std::string hlo_text = R"(
HloModule DotOperationFusion_TransposeFusion

ENTRY DotOperationFusion_TransposeFusion {
arg0 = f32[1,256] parameter(0)
arg1 = f32[1024,256] parameter(1)
exponential = f32[1024,256] exponential(arg1)
transpose = f32[256,1024] transpose(exponential), dimensions={1,0}
ROOT dot = f32[1,1024] dot(arg0, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  TF_ASSERT_OK_AND_ASSIGN(bool folded, TransposeFolding().Run(module.get()));
  ASSERT_TRUE(folded);

  const HloInstruction* root = module->entry_computation()->root_instruction();
  ASSERT_THAT(root, op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)),
                            /*lhs_contracting_dim=*/1,
                            /*rhs_contracting_dim=*/1));
}
192
TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion_LHS) {
  // The LHS is a transposed parameter; after TransposeFolding the dot reads
  // param 0 directly and contracts its dimension 0 instead of 1.
  const std::string hlo_text = R"(
HloModule DotOperationFusion_TransposeFusion

ENTRY DotOperationFusion_TransposeFusion {
arg0 = f32[256,1] parameter(0)
arg1 = f32[256,1024] parameter(1)
transpose = f32[1,256] transpose(arg0), dimensions={1,0}
exponential = f32[256,1024] exponential(arg1)
ROOT dot = f32[1,1024] dot(transpose, exponential), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  TF_ASSERT_OK_AND_ASSIGN(bool folded, TransposeFolding().Run(module.get()));
  ASSERT_TRUE(folded);

  const HloInstruction* root = module->entry_computation()->root_instruction();
  ASSERT_THAT(root, op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)),
                            /*lhs_contracting_dim=*/0,
                            /*rhs_contracting_dim=*/0));
}
216
TEST_F(InstructionFusionTest,
       DotOperationFusion_TransposeFusion_LHS_NonDefault) {
  // Here the original dot contracts LHS dimension 0 of the transposed operand;
  // after folding, it should contract dimension 1 of param 0 directly.
  const std::string hlo_text = R"(
HloModule DotOperationFusion_TransposeFusion

ENTRY DotOperationFusion_TransposeFusion {
arg0 = f32[1,256] parameter(0)
arg1 = f32[256,1024] parameter(1)
transpose = f32[256,1] transpose(arg0), dimensions={1,0}
exponential = f32[256,1024] exponential(arg1)
ROOT dot = f32[1,1024] dot(transpose, exponential), lhs_contracting_dims={0}, rhs_contracting_dims={0}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  TF_ASSERT_OK_AND_ASSIGN(bool folded, TransposeFolding().Run(module.get()));
  ASSERT_TRUE(folded);

  const HloInstruction* root = module->entry_computation()->root_instruction();
  ASSERT_THAT(root, op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)),
                            /*lhs_contracting_dim=*/1,
                            /*rhs_contracting_dim=*/0));
}
241
242 class OpcodeFusionTest : public InstructionFusionTest {
243 protected:
244 // Runs CPU instruction fusion on the given module, and tests that the result
245 // contains a fused op at the root with exactly the given multiset of opcodes.
RunFusionAndCheckOpcodesWereFused(HloModule * module,const std::multiset<HloOpcode> & expected_opcodes,HloInstruction::FusionKind fusion_kind=HloInstruction::FusionKind::kLoop)246 void RunFusionAndCheckOpcodesWereFused(
247 HloModule* module, const std::multiset<HloOpcode>& expected_opcodes,
248 HloInstruction::FusionKind fusion_kind =
249 HloInstruction::FusionKind::kLoop) {
250 auto computation = module->entry_computation();
251 auto did_fusion = CpuInstructionFusion().Run(module);
252 ASSERT_TRUE(did_fusion.ok());
253 EXPECT_TRUE(did_fusion.ValueOrDie());
254
255 HloInstruction* root = computation->root_instruction();
256 ASSERT_THAT(root, op::Fusion());
257 EXPECT_EQ(root->fusion_kind(), fusion_kind);
258
259 std::vector<HloOpcode> fused_opcodes(root->fused_instruction_count());
260 std::transform(root->fused_instructions().begin(),
261 root->fused_instructions().end(), fused_opcodes.begin(),
262 [](const HloInstruction* hlo) { return hlo->opcode(); });
263
264 EXPECT_EQ(
265 std::multiset<HloOpcode>(fused_opcodes.begin(), fused_opcodes.end()),
266 expected_opcodes);
267 }
268
CreateAdderToOne(HloModule * module)269 HloComputation* CreateAdderToOne(HloModule* module) {
270 HloComputation::Builder builder(TestName());
271 HloInstruction* arg0 =
272 builder.AddInstruction(HloInstruction::CreateParameter(
273 0, ShapeUtil::MakeShape(F32, {}), "arg0"));
274 HloInstruction* one = builder.AddInstruction(
275 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
276 builder.AddInstruction(HloInstruction::CreateBinary(
277 ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, arg0, one));
278 return module->AddEmbeddedComputation(builder.Build());
279 }
280
CreateMax(HloModule * module)281 HloComputation* CreateMax(HloModule* module) {
282 HloComputation::Builder builder(TestName());
283 HloInstruction* arg0 =
284 builder.AddInstruction(HloInstruction::CreateParameter(
285 0, ShapeUtil::MakeShape(F32, {}), "arg0"));
286 HloInstruction* arg1 =
287 builder.AddInstruction(HloInstruction::CreateParameter(
288 1, ShapeUtil::MakeShape(F32, {}), "arg1"));
289 builder.AddInstruction(HloInstruction::CreateBinary(
290 ShapeUtil::MakeShape(F32, {}), HloOpcode::kMaximum, arg0, arg1));
291 return module->AddEmbeddedComputation(builder.Build());
292 }
293 };
294
TEST_F(OpcodeFusionTest, Exponential_Reshape_Negate) {
  // param -> exp -> reshape([1,4] -> [4]) -> negate should fuse as one loop.
  const Shape input_shape = ShapeUtil::MakeShape(F32, {1, 4});
  const Shape output_shape = ShapeUtil::MakeShape(F32, {4});

  HloComputation::Builder builder(TestName());
  HloInstruction* const param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param"));
  HloInstruction* const exponential = builder.AddInstruction(
      HloInstruction::CreateUnary(input_shape, HloOpcode::kExp, param));
  HloInstruction* const reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(output_shape, exponential));
  builder.AddInstruction(
      HloInstruction::CreateUnary(output_shape, HloOpcode::kNegate, reshape));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kParameter, HloOpcode::kExp, HloOpcode::kReshape,
      HloOpcode::kNegate};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes);
}
315
TEST_F(OpcodeFusionTest, Broadcast_Reshape_DynamicSlice_Tanh) {
  // broadcast -> reshape -> dynamic-slice -> tanh, with two scalar S32 start
  // indices, is expected to collapse into one loop fusion.
  const Shape param_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape index_shape = ShapeUtil::MakeShape(S32, {});
  const Shape broadcast_shape = ShapeUtil::MakeShape(F32, {1, 8, 8});
  const Shape reshape_shape = ShapeUtil::MakeShape(F32, {8, 8});
  const Shape slice_shape = ShapeUtil::MakeShape(F32, {4, 4});

  HloComputation::Builder builder(TestName());
  HloInstruction* const data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param"));
  HloInstruction* const row_start = builder.AddInstruction(
      HloInstruction::CreateParameter(1, index_shape, "starts"));
  HloInstruction* const col_start = builder.AddInstruction(
      HloInstruction::CreateParameter(2, index_shape, "starts"));
  HloInstruction* const broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(broadcast_shape, data, {1}));
  HloInstruction* const reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(reshape_shape, broadcast));
  HloInstruction* const dynamic_slice =
      builder.AddInstruction(HloInstruction::CreateDynamicSlice(
          slice_shape, reshape, {row_start, col_start}, {4, 4}));
  builder.AddInstruction(
      HloInstruction::CreateUnary(slice_shape, HloOpcode::kTanh, dynamic_slice));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kParameter,    HloOpcode::kParameter, HloOpcode::kParameter,
      HloOpcode::kBroadcast,    HloOpcode::kReshape,
      HloOpcode::kDynamicSlice, HloOpcode::kTanh};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes);
}
348
TEST_F(OpcodeFusionTest, Broadcast_Negate) {
  // negate(broadcast(param)) is expected to fuse into a single loop fusion.
  const Shape input_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape output_shape = ShapeUtil::MakeShape(F32, {8, 8});

  HloComputation::Builder builder(TestName());
  HloInstruction* const param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param"));
  HloInstruction* const broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(output_shape, param, {1}));
  builder.AddInstruction(
      HloInstruction::CreateUnary(output_shape, HloOpcode::kNegate, broadcast));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kParameter, HloOpcode::kBroadcast, HloOpcode::kNegate};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes);
}
367
TEST_F(OpcodeFusionTest, DynamicSlice_Negate) {
  // negate(dynamic-slice(param, starts)) is expected to fuse, pulling in both
  // parameters.
  const Shape input_shape = ShapeUtil::MakeShape(F32, {4});
  const Shape index_shape = ShapeUtil::MakeShape(S32, {});
  const Shape output_shape = ShapeUtil::MakeShape(F32, {2});

  HloComputation::Builder builder(TestName());
  HloInstruction* const data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param"));
  HloInstruction* const start = builder.AddInstruction(
      HloInstruction::CreateParameter(1, index_shape, "starts"));
  HloInstruction* const dynamic_slice = builder.AddInstruction(
      HloInstruction::CreateDynamicSlice(output_shape, data, {start}, {2}));
  builder.AddInstruction(HloInstruction::CreateUnary(
      output_shape, HloOpcode::kNegate, dynamic_slice));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kDynamicSlice,
      HloOpcode::kNegate};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes);
}
389
TEST_F(OpcodeFusionTest, Exponential_Negate) {
  // negate(exp(param)): two elementwise ops expected to share one loop fusion.
  const Shape shape = ShapeUtil::MakeShape(F32, {4});

  HloComputation::Builder builder(TestName());
  HloInstruction* const param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param"));
  HloInstruction* const exponential = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kExp, param));
  builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, exponential));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kParameter, HloOpcode::kExp, HloOpcode::kNegate};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes);
}
407
TEST_F(OpcodeFusionTest, Reshape_Negate) {
  // negate(reshape([4,4] -> [16])) is expected to fuse into one loop.
  const Shape input_shape = ShapeUtil::MakeShape(F32, {4, 4});
  const Shape output_shape = ShapeUtil::MakeShape(F32, {16});

  HloComputation::Builder builder(TestName());
  HloInstruction* const param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param"));
  HloInstruction* const reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(output_shape, param));
  builder.AddInstruction(
      HloInstruction::CreateUnary(output_shape, HloOpcode::kNegate, reshape));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kParameter, HloOpcode::kReshape, HloOpcode::kNegate};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes);
}
426
TEST_F(OpcodeFusionTest, Reverse_Negate) {
  // negate(reverse(param, dim 0)) is expected to fuse into one loop.
  const Shape shape = ShapeUtil::MakeShape(F32, {8});

  HloComputation::Builder builder(TestName());
  HloInstruction* const param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param"));
  HloInstruction* const reverse = builder.AddInstruction(
      HloInstruction::CreateReverse(shape, param, {0}));
  builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, reverse));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kParameter, HloOpcode::kReverse, HloOpcode::kNegate};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes);
}
444
TEST_F(OpcodeFusionTest, Slice_Negate) {
  // negate(slice(param, [0:4:2])) is expected to fuse into one loop.
  const Shape input_shape = ShapeUtil::MakeShape(F32, {4});
  const Shape slice_shape = ShapeUtil::MakeShape(F32, {2});

  HloComputation::Builder builder(TestName());
  HloInstruction* const param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param"));
  HloInstruction* const slice = builder.AddInstruction(HloInstruction::CreateSlice(
      slice_shape, param, /*start_indices=*/{0}, /*limit_indices=*/{4},
      /*strides=*/{2}));
  builder.AddInstruction(
      HloInstruction::CreateUnary(slice_shape, HloOpcode::kNegate, slice));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kParameter, HloOpcode::kSlice, HloOpcode::kNegate};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes);
}
463
TEST_F(OpcodeFusionTest, Exponential_Transpose_Negate) {
  const Shape input_shape = ShapeUtil::MakeShape(F32, {3, 4});
  const Shape output_shape = ShapeUtil::MakeShape(F32, {4, 3});

  HloComputation::Builder builder(TestName());
  HloInstruction* const param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param"));
  // InstructionFusion::ShouldFuse() precludes fusing a transpose whose operand
  // is a parameter, so an exp sits between the parameter and the transpose.
  HloInstruction* const exponential = builder.AddInstruction(
      HloInstruction::CreateUnary(input_shape, HloOpcode::kExp, param));
  HloInstruction* const transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(output_shape, exponential, {1, 0}));
  builder.AddInstruction(
      HloInstruction::CreateUnary(output_shape, HloOpcode::kNegate, transpose));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kParameter, HloOpcode::kExp, HloOpcode::kTranspose,
      HloOpcode::kNegate};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes);
}
486
TEST_F(OpcodeFusionTest, UnaryMapOfExp) {
  // map(exp(param), x -> x + 1) is expected to fuse into one loop.
  auto module = CreateNewVerifiedModule();
  const Shape shape = ShapeUtil::MakeShape(F32, {3, 4});

  HloComputation::Builder builder(TestName());
  HloInstruction* const param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param"));
  HloInstruction* const exponential = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kExp, param));
  // The mapped computation is embedded in the module before the entry
  // computation is added.
  HloComputation* const add_one = CreateAdderToOne(module.get());
  builder.AddInstruction(
      HloInstruction::CreateMap(shape, {exponential}, add_one));

  module->AddEntryComputation(builder.Build());

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kParameter, HloOpcode::kExp, HloOpcode::kMap};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes);
}
505
TEST_F(OpcodeFusionTest, BinaryMapOfExps) {
  // map(exp(p0), exp(p1), max) is expected to fuse both exps and the map.
  auto module = CreateNewVerifiedModule();
  const Shape shape = ShapeUtil::MakeShape(F32, {3, 4});

  HloComputation::Builder builder(TestName());
  HloInstruction* const lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param"));
  HloInstruction* const rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape, "param"));
  HloInstruction* const exp_lhs = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kExp, lhs));
  HloInstruction* const exp_rhs = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kExp, rhs));

  HloComputation* const max = CreateMax(module.get());
  builder.AddInstruction(
      HloInstruction::CreateMap(shape, {exp_lhs, exp_rhs}, max));

  module->AddEntryComputation(builder.Build());

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kExp,
      HloOpcode::kExp, HloOpcode::kMap};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes);
}
530
TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) {
  // dynamic-update-slice(to_update, dynamic-slice(slice_from, ...), ...) is
  // expected to fuse, together with all eight parameters.
  auto module = CreateNewVerifiedModule();

  HloComputation::Builder builder(TestName());
  const Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
  const Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 1, 1000});
  const Shape index_shape = ShapeUtil::MakeShape(U32, {});

  // Scalar index parameters, created interleaved: slice indices are
  // parameters 1..3 and update indices are parameters 5..7.
  std::vector<HloInstruction*> slice_indices;
  std::vector<HloInstruction*> update_indices;
  for (int dim = 0; dim < 3; ++dim) {
    slice_indices.push_back(builder.AddInstruction(
        HloInstruction::CreateParameter(1 + dim, index_shape, "slice_indices")));
    update_indices.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
        5 + dim, index_shape, "update_indices")));
  }

  HloInstruction* const slice_from = builder.AddInstruction(
      HloInstruction::CreateParameter(0, full_shape, "slice_from"));
  HloInstruction* const slice =
      builder.AddInstruction(HloInstruction::CreateDynamicSlice(
          slice_shape, slice_from, slice_indices,
          /*slice_sizes=*/{10, 1, 1000}));

  HloInstruction* const to_update = builder.AddInstruction(
      HloInstruction::CreateParameter(4, full_shape, "to_update"));
  builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
      full_shape, to_update, slice, update_indices));

  module->AddEntryComputation(builder.Build());

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kDynamicSlice, HloOpcode::kDynamicUpdateSlice,
      HloOpcode::kParameter,    HloOpcode::kParameter,
      HloOpcode::kParameter,    HloOpcode::kParameter,
      HloOpcode::kParameter,    HloOpcode::kParameter,
      HloOpcode::kParameter,    HloOpcode::kParameter};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes);
}
569
TEST_F(OpcodeFusionTest, MessOfFusibleNodes) {
  // A chain of dynamic-slice / reshape / constant / dynamic-update-slice nodes
  // that is expected to collapse into a single loop fusion.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  const Shape full_shape = ShapeUtil::MakeShape(F32, {4, 100, 10, 100, 50});
  const Shape scalar_s32 = ShapeUtil::MakeShape(S32, {});

  HloInstruction* const loop_idx = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_s32, "param0"));
  HloInstruction* const param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_s32, "param1"));

  // idx_choice = reshape(dynamic-slice(param2, {loop_idx}, size 1)) -> scalar.
  HloInstruction* const choices =
      builder.AddInstruction(HloInstruction::CreateParameter(
          2, ShapeUtil::MakeShape(S32, {4}), "param2"));
  HloInstruction* const choice_slice =
      builder.AddInstruction(HloInstruction::CreateDynamicSlice(
          ShapeUtil::MakeShape(S32, {1}), choices, {loop_idx},
          /*slice_sizes=*/{1}));
  HloInstruction* const idx_choice = builder.AddInstruction(
      HloInstruction::CreateReshape(scalar_s32, choice_slice));

  HloInstruction* const zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));

  HloInstruction* const source =
      builder.AddInstruction(HloInstruction::CreateParameter(
          3, ShapeUtil::MakeShape(F32, {100, 100, 10, 100, 50}), "param3"));
  HloInstruction* const slice =
      builder.AddInstruction(HloInstruction::CreateDynamicSlice(
          ShapeUtil::MakeShape(F32, {1, 100, 10, 100, 50}), source,
          {idx_choice, zero, zero, zero, zero},
          /*slice_sizes=*/{1, 100, 10, 100, 50}));

  HloInstruction* const destination = builder.AddInstruction(
      HloInstruction::CreateParameter(4, full_shape, "param4"));
  builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
      full_shape, destination, slice,
      {loop_idx, param1, param1, param1, param1}));

  module->AddEntryComputation(builder.Build());

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kDynamicSlice, HloOpcode::kDynamicSlice,
      HloOpcode::kDynamicUpdateSlice, HloOpcode::kReshape,
      HloOpcode::kConstant,     HloOpcode::kParameter,
      HloOpcode::kParameter,    HloOpcode::kParameter,
      HloOpcode::kParameter,    HloOpcode::kParameter};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes);
}
613
CreateComputationForDotAddOutputFusionTest(const std::string & test_name,HloModule * module,int m,int k,int n,bool add_extra_use_for_dot)614 void CreateComputationForDotAddOutputFusionTest(const std::string& test_name,
615 HloModule* module, int m, int k,
616 int n,
617 bool add_extra_use_for_dot) {
618 HloComputation::Builder builder(test_name);
619
620 Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
621 Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
622 Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});
623 if (m == 1) {
624 dot_lhs_shape = ShapeUtil::MakeShape(F32, {k});
625 dot_shape = ShapeUtil::MakeShape(F32, {n});
626 } else if (n == 1) {
627 dot_rhs_shape = ShapeUtil::MakeShape(F32, {k});
628 dot_shape = ShapeUtil::MakeShape(F32, {m});
629 }
630
631 auto* dot_lhs = builder.AddInstruction(
632 HloInstruction::CreateParameter(0, dot_lhs_shape, "param0"));
633 auto* dot_rhs = builder.AddInstruction(
634 HloInstruction::CreateParameter(1, dot_rhs_shape, "param1"));
635 auto* addend = builder.AddInstruction(
636 HloInstruction::CreateParameter(2, dot_shape, "param2"));
637
638 auto* dot =
639 builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
640 builder.AddInstruction(
641 HloInstruction::CreateBinary(dot_shape, HloOpcode::kAdd, dot, addend));
642
643 if (add_extra_use_for_dot) {
644 auto* token = builder.AddInstruction(HloInstruction::CreateToken());
645 builder.AddInstruction(
646 HloInstruction::CreateOutfeed(dot_shape, dot, token, "no_config"));
647 }
648
649 module->AddEntryComputation(builder.Build());
650 }
651
TEST_F(OpcodeFusionTest, DotAddOutputFusion_1x50x19) {
  // Vector-matrix dot (m == 1) plus an add: expected to become a kOutput fusion
  // containing the dot, the add and the three parameters.
  auto module = CreateNewVerifiedModule();
  CreateComputationForDotAddOutputFusionTest(TestName(), module.get(),
                                             /*m=*/1, /*k=*/50, /*n=*/19,
                                             /*add_extra_use_for_dot=*/false);

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kDot, HloOpcode::kAdd, HloOpcode::kParameter,
      HloOpcode::kParameter, HloOpcode::kParameter};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes,
                                    HloInstruction::FusionKind::kOutput);
}
664
TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1) {
  // Matrix-vector dot (n == 1) plus an add: expected to become a kOutput fusion
  // containing the dot, the add and the three parameters.
  auto module = CreateNewVerifiedModule();
  CreateComputationForDotAddOutputFusionTest(TestName(), module.get(),
                                             /*m=*/19, /*k=*/50, /*n=*/1,
                                             /*add_extra_use_for_dot=*/false);

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kDot, HloOpcode::kAdd, HloOpcode::kParameter,
      HloOpcode::kParameter, HloOpcode::kParameter};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes,
                                    HloInstruction::FusionKind::kOutput);
}
677
TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x19) {
  // Neither dot operand is a vector here; the test expects the pass to leave
  // the computation unfused.
  auto module = CreateNewVerifiedModule();
  CreateComputationForDotAddOutputFusionTest(TestName(), module.get(),
                                             /*m=*/19, /*k=*/50, /*n=*/19,
                                             /*add_extra_use_for_dot=*/false);

  TF_ASSERT_OK_AND_ASSIGN(bool fused_something,
                          CpuInstructionFusion().Run(module.get()));
  EXPECT_FALSE(fused_something);

  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, ::testing::Not(op::Fusion()));
}
690
TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) {
  // Same shapes as the fusible 19x50x1 case, but the dot also feeds an
  // outfeed; the test expects no fusion.
  auto module = CreateNewVerifiedModule();
  CreateComputationForDotAddOutputFusionTest(TestName(), module.get(),
                                             /*m=*/19, /*k=*/50, /*n=*/1,
                                             /*add_extra_use_for_dot=*/true);

  TF_ASSERT_OK_AND_ASSIGN(bool fused_something,
                          CpuInstructionFusion().Run(module.get()));
  EXPECT_FALSE(fused_something);

  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, ::testing::Not(op::Fusion()));
}
703
// The root add consumes the same dot result `c` as both of its operands
// (add(c, c)). The pass is expected to make no change and leave a non-fusion
// root.
TEST_F(InstructionFusionTest,
       DotOperationFusion_DontOutputFuseDuplicateOperands) {
  absl::string_view hlo_text = R"(
HloModule module

ENTRY main {
  a = f32[50,60]{1,0} parameter(0)
  b = f32[60,1]{1,0} parameter(1)
  c = f32[50,1]{1,0} dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT d = f32[50,1]{1,0} add(c, c)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          CpuInstructionFusion().Run(parsed_module.get()));
  EXPECT_FALSE(changed);

  const HloInstruction* root =
      parsed_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, Not(op::Fusion()));
}
725
726 struct GatherLoopFusionTestSpec {
727 std::string test_name;
728 std::string hlo_computation_text;
729
Namexla::cpu::__anon8130c97e0111::GatherLoopFusionTestSpec730 static std::string Name(
731 const ::testing::TestParamInfo<GatherLoopFusionTestSpec>& info) {
732 return info.param.test_name;
733 }
734 };
735
736 class GatherLoopFusionTest
737 : public OpcodeFusionTest,
738 public ::testing::WithParamInterface<GatherLoopFusionTestSpec> {};
739
// Builds a module from the current spec by prefixing "HloModule <test_name>".
// It then checks that the gather, add, broadcast, constant and both parameters
// are the fused opcodes. No explicit fusion kind is passed, so the helper's
// default applies.
TEST_P(GatherLoopFusionTest, GatherLoopFusion) {
  const GatherLoopFusionTestSpec& param = GetParam();
  const std::string module_text =
      absl::StrCat("HloModule ", param.test_name, "\n\n",
                   param.hlo_computation_text);
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(module_text));

  RunFusionAndCheckOpcodesWereFused(
      parsed_module.get(),
      {HloOpcode::kGather, HloOpcode::kAdd, HloOpcode::kBroadcast,
       HloOpcode::kConstant, HloOpcode::kParameter, HloOpcode::kParameter});
}
752
GetGatherLoopFusionTestSpecs()753 std::vector<GatherLoopFusionTestSpec> GetGatherLoopFusionTestSpecs() {
754 std::vector<GatherLoopFusionTestSpec> result;
755
756 result.push_back({"FusedTensorFlowGatherV2", R"(
757 ENTRY main {
758 operand = s32[3,3] parameter(0)
759 indices = s32[2] parameter(1)
760 gather = s32[3,2] gather(operand, indices),
761 offset_dims={0},
762 collapsed_slice_dims={1},
763 start_index_map={1},
764 index_vector_dim=1,
765 slice_sizes={3, 1}
766 one = s32[] constant(1)
767 one_broadcasted = s32[3,2] broadcast(one), dimensions={}
768 ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted)
769 }
770 )"});
771
772 result.push_back({"FusedTensorFlowGatherMultipleBatchDims", R"(
773 ENTRY main {
774 operand = s32[3,3] parameter(0)
775 indices = s32[2,2] parameter(1)
776 gather = s32[2,3,2] gather(operand, indices),
777 offset_dims={1},
778 collapsed_slice_dims={1},
779 start_index_map={1},
780 index_vector_dim=2,
781 slice_sizes={3, 1}
782 one = s32[] constant(1)
783 one_broadcasted = s32[2,3,2] broadcast(one), dimensions={}
784 ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted)
785 }
786 )"});
787
788 result.push_back({"FusedTensorFlowGatherNdMultipleBatchDims", R"(
789 ENTRY main {
790 operand = s32[3,3] parameter(0)
791 indices = s32[2,2,2] parameter(1)
792 gather = s32[2,2] gather(operand, indices),
793 offset_dims={},
794 collapsed_slice_dims={0,1},
795 start_index_map={0,1},
796 index_vector_dim=2,
797 slice_sizes={1, 1}
798 one = s32[] constant(1)
799 one_broadcasted = s32[2,2] broadcast(one), dimensions={}
800 ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
801 }
802 )"});
803
804 result.push_back({"FusedTensorFlowGatherNd_0", R"(
805 ENTRY main {
806 operand = s32[3,3,2] parameter(0)
807 indices = s32[2,2] parameter(1)
808 gather = s32[2,2] gather(operand, indices),
809 offset_dims={1},
810 collapsed_slice_dims={0,1},
811 start_index_map={0,1},
812 index_vector_dim=1,
813 slice_sizes={1,1,2}
814 one = s32[] constant(1)
815 one_broadcasted = s32[2,2] broadcast(one), dimensions={}
816 ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
817 }
818 )"});
819
820 result.push_back({"FusedTensorFlowGatherNd_1", R"(
821 ENTRY main {
822 operand = s32[3,3,2] parameter(0)
823 indices = s32[2,2] parameter(1)
824 gather = s32[2,2] gather(operand, indices),
825 offset_dims={1},
826 collapsed_slice_dims={0,1},
827 start_index_map={0,1},
828 index_vector_dim=0,
829 slice_sizes={1,1,2}
830 one = s32[] constant(1)
831 one_broadcasted = s32[2,2] broadcast(one), dimensions={}
832 ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
833 }
834 )"});
835
836 result.push_back({"FusedDynamicSlice", R"(
837 ENTRY main {
838 operand = s32[3,3] parameter(0)
839 indices = s32[2] parameter(1)
840 gather = s32[1,1] gather(operand, indices),
841 offset_dims={0,1},
842 collapsed_slice_dims={},
843 start_index_map={0,1},
844 index_vector_dim=0,
845 slice_sizes={1,1}
846 one = s32[] constant(1)
847 one_broadcasted = s32[1,1] broadcast(one), dimensions={}
848 ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted)
849 }
850 )"});
851
852 result.push_back({"FusedBatchDynamicSlice", R"(
853 ENTRY main {
854 operand = s32[3,3] parameter(0)
855 indices = s32[2,2] parameter(1)
856 gather = s32[2,1,1] gather(operand, indices),
857 offset_dims={1,2},
858 collapsed_slice_dims={},
859 start_index_map={0,1},
860 index_vector_dim=0,
861 slice_sizes={1,1}
862 one = s32[] constant(1)
863 one_broadcasted = s32[2,1,1] broadcast(one), dimensions={}
864 ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted)
865 }
866 )"});
867
868 return result;
869 }
870
871 INSTANTIATE_TEST_SUITE_P(GatherLoopFusionTestInstantiation,
872 GatherLoopFusionTest,
873 ::testing::ValuesIn(GetGatherLoopFusionTestSpecs()),
874 GatherLoopFusionTestSpec::Name);
875
// An elementwise add feeds a reduce over dimension 0, which is the major
// dimension under the {1,0} layout. The pass is expected to make no change and
// leave a non-fusion root.
TEST_F(InstructionFusionTest, NoFuseReduceMajor) {
  absl::string_view hlo_text = R"(
HloModule module

add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY main {
  a = f32[50,60]{1,0} parameter(0)
  b = f32[50,60]{1,0} parameter(1)
  c = f32[50,60]{1,0} add(a, b)
  init = f32[] constant(0)
  ROOT r = f32[60]{0} reduce(c, init), dimensions={0}, to_apply=add
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          CpuInstructionFusion().Run(parsed_module.get()));
  EXPECT_FALSE(changed);

  const HloInstruction* root =
      parsed_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, Not(op::Fusion()));
}
903
// Same add-then-reduce pattern as NoFuseReduceMajor, but the reduce collapses
// both dimensions {0,1}, including the minor dimension 1. Here the pass is
// expected to report a change and the entry root to become a fusion.
TEST_F(InstructionFusionTest, FuseReduceMinor) {
  absl::string_view hlo_text = R"(
HloModule module

add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY main {
  a = f32[50,60]{1,0} parameter(0)
  b = f32[50,60]{1,0} parameter(1)
  c = f32[50,60]{1,0} add(a, b)
  init = f32[] constant(0)
  ROOT r = f32[] reduce(c, init), dimensions={0,1}, to_apply=add
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          CpuInstructionFusion().Run(parsed_module.get()));
  EXPECT_TRUE(changed);

  const HloInstruction* root =
      parsed_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Fusion());
}
930 } // namespace
931 } // namespace cpu
932 } // namespace xla
933