• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <set>
21 
22 #include "absl/strings/str_cat.h"
23 #include "absl/types/span.h"
24 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
25 #include "tensorflow/compiler/xla/service/transpose_folding.h"
26 #include "tensorflow/compiler/xla/shape.h"
27 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
28 #include "tensorflow/compiler/xla/tests/test_utils.h"
29 
30 namespace op = xla::testing::opcode_matchers;
31 
32 namespace xla {
33 namespace cpu {
34 namespace {
35 
36 using InstructionFusionTest = HloTestBase;
37 
MakeDot(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs)38 std::unique_ptr<HloInstruction> MakeDot(const Shape& shape, HloInstruction* lhs,
39                                         HloInstruction* rhs) {
40   DotDimensionNumbers dot_dnums;
41   dot_dnums.add_lhs_contracting_dimensions(lhs->shape().rank() - 1);
42   dot_dnums.add_rhs_contracting_dimensions(0);
43   PrecisionConfig precision_config;
44   precision_config.mutable_operand_precision()->Resize(
45       2, PrecisionConfig::DEFAULT);
46   return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums,
47                                    precision_config);
48 }
49 
// Builds dot(exp(arg0), arg1) with arg0 of shape [1024,256] and arg1 of shape
// [256], runs CpuInstructionFusion, and expects the pass to report a change
// and the root instruction to become a fusion.
TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) {
  HloComputation::Builder builder(TestName());
  HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1024, 256}), "arg0"));
  HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {256}), "arg1"));

  HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShape(F32, {1024, 256}), HloOpcode::kExp, arg0));
  HloInstruction* dot = builder.AddInstruction(
      MakeDot(ShapeUtil::MakeShape(F32, {1024}), exp0, arg1));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(dot, computation->root_instruction());
  // Report a non-OK pass status as a test failure rather than aborting the
  // test binary via ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          CpuInstructionFusion().Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(computation->root_instruction(), op::Fusion());
}
68 
// Builds dot(arg0, exp(arg1)) with arg0 of shape [256] and arg1 of shape
// [256,1024], runs CpuInstructionFusion, and expects the pass to report a
// change and the root instruction to become a fusion.
TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) {
  HloComputation::Builder builder(TestName());
  HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {256}), "arg0"));
  HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {256, 1024}), "arg1"));

  HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShape(F32, {256, 1024}), HloOpcode::kExp, arg1));
  HloInstruction* dot = builder.AddInstruction(
      MakeDot(ShapeUtil::MakeShape(F32, {1024}), arg0, exp1));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(dot, computation->root_instruction());
  // Report a non-OK pass status as a test failure rather than aborting the
  // test binary via ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          CpuInstructionFusion().Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(computation->root_instruction(), op::Fusion());
}
87 
// Builds dot(bitcast(exp(arg0)), arg1), where the bitcast turns the
// [2,512,2,128] exp result into [1024,256]. Runs CpuInstructionFusion and
// expects the pass to report a change and the root to become a fusion.
TEST_F(InstructionFusionTest, DotOperationFusion_Bitcast) {
  HloComputation::Builder builder(TestName());
  HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0"));
  HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {256}), "arg1"));

  HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), HloOpcode::kExp, arg0));
  HloInstruction* bitcast0 = builder.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShape(F32, {1024, 256}), HloOpcode::kBitcast, exp0));
  HloInstruction* dot = builder.AddInstruction(
      MakeDot(ShapeUtil::MakeShape(F32, {1024}), bitcast0, arg1));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(dot, computation->root_instruction());
  // Report a non-OK pass status as a test failure rather than aborting the
  // test binary via ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          CpuInstructionFusion().Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(computation->root_instruction(), op::Fusion());
}
108 
// Builds dot(reshape(exp(arg0)), arg1), where the reshape turns the
// [2,512,2,128] exp result into [1024,256]. Runs CpuInstructionFusion and
// expects the pass to report a change and the root to become a fusion.
TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) {
  HloComputation::Builder builder(TestName());
  HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0"));
  HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {256}), "arg1"));

  HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), HloOpcode::kExp, arg0));
  HloInstruction* reshape0 =
      builder.AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(F32, {1024, 256}), exp0));
  HloInstruction* dot = builder.AddInstruction(
      MakeDot(ShapeUtil::MakeShape(F32, {1024}), reshape0, arg1));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(dot, computation->root_instruction());
  // Report a non-OK pass status as a test failure rather than aborting the
  // test binary via ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          CpuInstructionFusion().Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_THAT(computation->root_instruction(), op::Fusion());
}
130 
// Builds dot(arg0, exp(arg1)) with arg0 of shape [32*1024] and arg1 of shape
// [32*1024,256], runs CpuInstructionFusion, and expects the pass to report no
// change and the root to remain the original dot.
// NOTE(review): per the test name, the operands are presumably too large for
// the pass to fuse; the threshold itself is not visible in this file.
TEST_F(InstructionFusionTest, DotOperationFusion_TooLarge) {
  HloComputation::Builder builder(TestName());
  HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {32 * 1024}), "arg0"));
  HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {32 * 1024, 256}), "arg1"));

  HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShape(F32, {32 * 1024, 256}), HloOpcode::kExp, arg1));
  HloInstruction* dot = builder.AddInstruction(
      MakeDot(ShapeUtil::MakeShape(F32, {256}), arg0, exp1));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(dot, computation->root_instruction());
  // Report a non-OK pass status as a test failure rather than aborting the
  // test binary via ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          CpuInstructionFusion().Run(module.get()));
  EXPECT_FALSE(changed);
  EXPECT_EQ(dot, computation->root_instruction());
}
149 
// Builds dot(arg0, exp(arg1)) with arg0 of shape [2,256] and arg1 of shape
// [256,1024], runs CpuInstructionFusion, and expects the pass to report no
// change and the root to remain the original dot.
// NOTE(review): per the test name, fusion is presumably rejected because the
// dot would reuse elements of the exp result; confirm against the pass.
TEST_F(InstructionFusionTest, DotOperationFusion_ElementReuse) {
  HloComputation::Builder builder(TestName());
  HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {2, 256}), "arg0"));
  HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {256, 1024}), "arg1"));

  HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeShape(F32, {256, 1024}), HloOpcode::kExp, arg1));
  HloInstruction* dot = builder.AddInstruction(
      MakeDot(ShapeUtil::MakeShape(F32, {2, 1024}), arg0, exp1));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(dot, computation->root_instruction());
  // Report a non-OK pass status as a test failure rather than aborting the
  // test binary via ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          CpuInstructionFusion().Run(module.get()));
  EXPECT_FALSE(changed);
  EXPECT_EQ(dot, computation->root_instruction());
}
168 
// Parses a module whose dot consumes transpose(exp(arg1)) on the RHS, runs
// TransposeFolding, and requires the root to become
// dot(arg0, exp(arg1)) with contracting dimensions lhs=1, rhs=1.
TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion_RHS) {
  const std::string module_text = R"(
HloModule DotOperationFusion_TransposeFusion

ENTRY DotOperationFusion_TransposeFusion {
  arg0 = f32[1,256] parameter(0)
  arg1 = f32[1024,256] parameter(1)
  exponential = f32[1024,256] exponential(arg1)
  transpose = f32[256,1024] transpose(exponential), dimensions={1,0}
  ROOT dot = f32[1,1024] dot(arg0, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(module_text));
  HloComputation* entry = module->entry_computation();

  TF_ASSERT_OK_AND_ASSIGN(bool folded, TransposeFolding().Run(module.get()));
  ASSERT_TRUE(folded);

  HloInstruction* root = entry->root_instruction();
  ASSERT_THAT(root, op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)),
                            /*lhs_contracting_dim=*/1,
                            /*rhs_contracting_dim=*/1));
}
192 
// Parses a module whose dot consumes transpose(arg0) on the LHS, runs
// TransposeFolding, and requires the root to become
// dot(arg0, exp(arg1)) with contracting dimensions lhs=0, rhs=0.
TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion_LHS) {
  const std::string module_text = R"(
HloModule DotOperationFusion_TransposeFusion

ENTRY DotOperationFusion_TransposeFusion {
  arg0 = f32[256,1] parameter(0)
  arg1 = f32[256,1024] parameter(1)
  transpose = f32[1,256] transpose(arg0), dimensions={1,0}
  exponential = f32[256,1024] exponential(arg1)
  ROOT dot = f32[1,1024] dot(transpose, exponential), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(module_text));
  HloComputation* entry = module->entry_computation();

  TF_ASSERT_OK_AND_ASSIGN(bool folded, TransposeFolding().Run(module.get()));
  ASSERT_TRUE(folded);

  HloInstruction* root = entry->root_instruction();
  ASSERT_THAT(root, op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)),
                            /*lhs_contracting_dim=*/0,
                            /*rhs_contracting_dim=*/0));
}
216 
// Same as the LHS transpose-folding case, but the input dot contracts LHS
// dimension 0 instead of 1. After TransposeFolding the root must be
// dot(arg0, exp(arg1)) with contracting dimensions lhs=1, rhs=0.
TEST_F(InstructionFusionTest,
       DotOperationFusion_TransposeFusion_LHS_NonDefault) {
  const std::string module_text = R"(
HloModule DotOperationFusion_TransposeFusion

ENTRY DotOperationFusion_TransposeFusion {
  arg0 = f32[1,256] parameter(0)
  arg1 = f32[256,1024] parameter(1)
  transpose = f32[256,1] transpose(arg0), dimensions={1,0}
  exponential = f32[256,1024] exponential(arg1)
  ROOT dot = f32[1,1024] dot(transpose, exponential), lhs_contracting_dims={0}, rhs_contracting_dims={0}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(module_text));
  HloComputation* entry = module->entry_computation();

  TF_ASSERT_OK_AND_ASSIGN(bool folded, TransposeFolding().Run(module.get()));
  ASSERT_TRUE(folded);

  HloInstruction* root = entry->root_instruction();
  ASSERT_THAT(root, op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)),
                            /*lhs_contracting_dim=*/1,
                            /*rhs_contracting_dim=*/0));
}
241 
242 class OpcodeFusionTest : public InstructionFusionTest {
243  protected:
244   // Runs CPU instruction fusion on the given module, and tests that the result
245   // contains a fused op at the root with exactly the given multiset of opcodes.
RunFusionAndCheckOpcodesWereFused(HloModule * module,const std::multiset<HloOpcode> & expected_opcodes,HloInstruction::FusionKind fusion_kind=HloInstruction::FusionKind::kLoop)246   void RunFusionAndCheckOpcodesWereFused(
247       HloModule* module, const std::multiset<HloOpcode>& expected_opcodes,
248       HloInstruction::FusionKind fusion_kind =
249           HloInstruction::FusionKind::kLoop) {
250     auto computation = module->entry_computation();
251     auto did_fusion = CpuInstructionFusion().Run(module);
252     ASSERT_TRUE(did_fusion.ok());
253     EXPECT_TRUE(did_fusion.ValueOrDie());
254 
255     HloInstruction* root = computation->root_instruction();
256     ASSERT_THAT(root, op::Fusion());
257     EXPECT_EQ(root->fusion_kind(), fusion_kind);
258 
259     std::vector<HloOpcode> fused_opcodes(root->fused_instruction_count());
260     std::transform(root->fused_instructions().begin(),
261                    root->fused_instructions().end(), fused_opcodes.begin(),
262                    [](const HloInstruction* hlo) { return hlo->opcode(); });
263 
264     EXPECT_EQ(
265         std::multiset<HloOpcode>(fused_opcodes.begin(), fused_opcodes.end()),
266         expected_opcodes);
267   }
268 
CreateAdderToOne(HloModule * module)269   HloComputation* CreateAdderToOne(HloModule* module) {
270     HloComputation::Builder builder(TestName());
271     HloInstruction* arg0 =
272         builder.AddInstruction(HloInstruction::CreateParameter(
273             0, ShapeUtil::MakeShape(F32, {}), "arg0"));
274     HloInstruction* one = builder.AddInstruction(
275         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
276     builder.AddInstruction(HloInstruction::CreateBinary(
277         ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, arg0, one));
278     return module->AddEmbeddedComputation(builder.Build());
279   }
280 
CreateMax(HloModule * module)281   HloComputation* CreateMax(HloModule* module) {
282     HloComputation::Builder builder(TestName());
283     HloInstruction* arg0 =
284         builder.AddInstruction(HloInstruction::CreateParameter(
285             0, ShapeUtil::MakeShape(F32, {}), "arg0"));
286     HloInstruction* arg1 =
287         builder.AddInstruction(HloInstruction::CreateParameter(
288             1, ShapeUtil::MakeShape(F32, {}), "arg1"));
289     builder.AddInstruction(HloInstruction::CreateBinary(
290         ShapeUtil::MakeShape(F32, {}), HloOpcode::kMaximum, arg0, arg1));
291     return module->AddEmbeddedComputation(builder.Build());
292   }
293 };
294 
// negate(reshape(exp(param))): expects a single fusion holding the negate,
// reshape, exp and the parameter.
TEST_F(OpcodeFusionTest, Exponential_Reshape_Negate) {
  const Shape input_shape = ShapeUtil::MakeShape(F32, {1, 4});
  const Shape output_shape = ShapeUtil::MakeShape(F32, {4});

  HloComputation::Builder builder(TestName());
  auto* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param"));
  auto* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(input_shape, HloOpcode::kExp, input));
  auto* reshaped =
      builder.AddInstruction(HloInstruction::CreateReshape(output_shape, exp));
  builder.AddInstruction(
      HloInstruction::CreateUnary(output_shape, HloOpcode::kNegate, reshaped));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kNegate, HloOpcode::kReshape, HloOpcode::kExp,
      HloOpcode::kParameter};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes);
}
315 
// tanh(dynamic-slice(reshape(broadcast(param)))) with two scalar S32 start
// indices: expects one fusion holding all four ops and the three parameters.
TEST_F(OpcodeFusionTest, Broadcast_Reshape_DynamicSlice_Tanh) {
  const Shape vector_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape index_shape = ShapeUtil::MakeShape(S32, {});
  const Shape broadcasted_shape = ShapeUtil::MakeShape(F32, {1, 8, 8});
  const Shape reshaped_shape = ShapeUtil::MakeShape(F32, {8, 8});
  const Shape sliced_shape = ShapeUtil::MakeShape(F32, {4, 4});

  HloComputation::Builder builder(TestName());
  auto* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape, "param"));
  auto* start0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, index_shape, "starts"));
  auto* start1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, index_shape, "starts"));
  auto* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(broadcasted_shape, input, {1}));
  auto* reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(reshaped_shape, broadcast));
  auto* dynamic_slice =
      builder.AddInstruction(HloInstruction::CreateDynamicSlice(
          sliced_shape, reshape, {start0, start1}, {4, 4}));
  builder.AddInstruction(HloInstruction::CreateUnary(
      sliced_shape, HloOpcode::kTanh, dynamic_slice));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kTanh,      HloOpcode::kDynamicSlice, HloOpcode::kReshape,
      HloOpcode::kBroadcast, HloOpcode::kParameter,    HloOpcode::kParameter,
      HloOpcode::kParameter};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes);
}
348 
// negate(broadcast(param)): expects a single fusion holding the negate, the
// broadcast and the parameter.
TEST_F(OpcodeFusionTest, Broadcast_Negate) {
  const Shape input_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape output_shape = ShapeUtil::MakeShape(F32, {8, 8});

  HloComputation::Builder builder(TestName());
  auto* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param"));
  auto* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(output_shape, input, {1}));
  builder.AddInstruction(HloInstruction::CreateUnary(
      output_shape, HloOpcode::kNegate, broadcast));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kNegate, HloOpcode::kBroadcast, HloOpcode::kParameter};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes);
}
367 
// negate(dynamic-slice(param, start)): expects a single fusion holding the
// negate, the dynamic-slice and both parameters.
TEST_F(OpcodeFusionTest, DynamicSlice_Negate) {
  const Shape input_shape = ShapeUtil::MakeShape(F32, {4});
  const Shape index_shape = ShapeUtil::MakeShape(S32, {});
  const Shape output_shape = ShapeUtil::MakeShape(F32, {2});

  HloComputation::Builder builder(TestName());
  auto* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param"));
  auto* start = builder.AddInstruction(
      HloInstruction::CreateParameter(1, index_shape, "starts"));
  auto* dynamic_slice = builder.AddInstruction(
      HloInstruction::CreateDynamicSlice(output_shape, input, {start}, {2}));
  builder.AddInstruction(HloInstruction::CreateUnary(
      output_shape, HloOpcode::kNegate, dynamic_slice));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kNegate, HloOpcode::kDynamicSlice, HloOpcode::kParameter,
      HloOpcode::kParameter};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes);
}
389 
// negate(exp(param)): expects a single fusion holding the negate, the exp and
// the parameter.
TEST_F(OpcodeFusionTest, Exponential_Negate) {
  const Shape shape = ShapeUtil::MakeShape(F32, {4});

  HloComputation::Builder builder(TestName());
  auto* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param"));
  auto* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kExp, input));
  builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, exp));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kNegate, HloOpcode::kExp, HloOpcode::kParameter};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes);
}
407 
// negate(reshape(param)): expects a single fusion holding the negate, the
// reshape and the parameter.
TEST_F(OpcodeFusionTest, Reshape_Negate) {
  const Shape input_shape = ShapeUtil::MakeShape(F32, {4, 4});
  const Shape output_shape = ShapeUtil::MakeShape(F32, {16});

  HloComputation::Builder builder(TestName());
  auto* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param"));
  auto* reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(output_shape, input));
  builder.AddInstruction(
      HloInstruction::CreateUnary(output_shape, HloOpcode::kNegate, reshape));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kNegate, HloOpcode::kReshape, HloOpcode::kParameter};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes);
}
426 
// negate(reverse(param)): expects a single fusion holding the negate, the
// reverse and the parameter.
TEST_F(OpcodeFusionTest, Reverse_Negate) {
  const Shape shape = ShapeUtil::MakeShape(F32, {8});

  HloComputation::Builder builder(TestName());
  auto* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param"));
  auto* reversed = builder.AddInstruction(
      HloInstruction::CreateReverse(shape, input, {0}));
  builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, reversed));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kNegate, HloOpcode::kReverse, HloOpcode::kParameter};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes);
}
444 
// negate(slice(param)) where the slice takes [0, 4) with stride 2: expects a
// single fusion holding the negate, the slice and the parameter.
TEST_F(OpcodeFusionTest, Slice_Negate) {
  const Shape input_shape = ShapeUtil::MakeShape(F32, {4});
  const Shape sliced_shape = ShapeUtil::MakeShape(F32, {2});

  HloComputation::Builder builder(TestName());
  auto* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param"));
  auto* slice = builder.AddInstruction(
      HloInstruction::CreateSlice(sliced_shape, input, {0}, {4}, {2}));
  builder.AddInstruction(
      HloInstruction::CreateUnary(sliced_shape, HloOpcode::kNegate, slice));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kNegate, HloOpcode::kSlice, HloOpcode::kParameter};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes);
}
463 
// negate(transpose(exp(param))): expects a single fusion holding the negate,
// the transpose, the exp and the parameter.
TEST_F(OpcodeFusionTest, Exponential_Transpose_Negate) {
  const Shape input_shape = ShapeUtil::MakeShape(F32, {3, 4});
  const Shape transposed_shape = ShapeUtil::MakeShape(F32, {4, 3});

  HloComputation::Builder builder(TestName());
  auto* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param"));
  // InstructionFusion::ShouldFuse() will not fuse a transpose that reads
  // directly from a parameter, so an exp is placed between the two.
  auto* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(input_shape, HloOpcode::kExp, input));
  auto* transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(transposed_shape, exp, {1, 0}));
  builder.AddInstruction(HloInstruction::CreateUnary(
      transposed_shape, HloOpcode::kNegate, transpose));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kNegate, HloOpcode::kTranspose, HloOpcode::kExp,
      HloOpcode::kParameter};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes);
}
486 
// map(exp(param)) using the add-one computation: expects a single fusion
// holding the map, the exp and the parameter.
TEST_F(OpcodeFusionTest, UnaryMapOfExp) {
  auto module = CreateNewVerifiedModule();
  const Shape shape = ShapeUtil::MakeShape(F32, {3, 4});

  HloComputation::Builder builder(TestName());
  auto* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param"));
  auto* exp_of_input = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kExp, input));

  HloComputation* add_one = CreateAdderToOne(module.get());
  builder.AddInstruction(
      HloInstruction::CreateMap(shape, {exp_of_input}, add_one));

  module->AddEntryComputation(builder.Build());

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kParameter, HloOpcode::kExp, HloOpcode::kMap};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes);
}
505 
// map(exp(param0), exp(param1)) using the max computation: expects a single
// fusion holding the map, both exps and both parameters.
TEST_F(OpcodeFusionTest, BinaryMapOfExps) {
  auto module = CreateNewVerifiedModule();
  const Shape shape = ShapeUtil::MakeShape(F32, {3, 4});

  HloComputation::Builder builder(TestName());
  auto* lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param"));
  auto* rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape, "param"));

  auto* exp_lhs = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kExp, lhs));
  auto* exp_rhs = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kExp, rhs));

  HloComputation* max_computation = CreateMax(module.get());
  builder.AddInstruction(
      HloInstruction::CreateMap(shape, {exp_lhs, exp_rhs}, max_computation));

  module->AddEntryComputation(builder.Build());

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kExp,
      HloOpcode::kExp, HloOpcode::kMap};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes);
}
530 
// dynamic-update-slice(to_update, dynamic-slice(slice_from, ...), ...) with
// three scalar U32 indices for each op: expects a single fusion holding both
// ops and all eight parameters.
TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) {
  auto module = CreateNewVerifiedModule();

  const Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000});
  const Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 1, 1000});

  HloComputation::Builder builder(TestName());

  // Parameters 1..3 index the dynamic-slice; parameters 5..7 index the
  // dynamic-update-slice. They are created interleaved, one pair per
  // dimension.
  std::vector<HloInstruction*> slice_indices;
  std::vector<HloInstruction*> update_indices;
  for (int dim = 0; dim < 3; ++dim) {
    auto* slice_index = builder.AddInstruction(HloInstruction::CreateParameter(
        1 + dim, ShapeUtil::MakeShape(U32, {}), "slice_indices"));
    slice_indices.push_back(slice_index);
    auto* update_index = builder.AddInstruction(HloInstruction::CreateParameter(
        5 + dim, ShapeUtil::MakeShape(U32, {}), "update_indices"));
    update_indices.push_back(update_index);
  }

  auto* slice_from = builder.AddInstruction(
      HloInstruction::CreateParameter(0, full_shape, "slice_from"));
  auto* slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
      slice_shape, slice_from, slice_indices,
      /*slice_sizes=*/{10, 1, 1000}));

  auto* to_update = builder.AddInstruction(
      HloInstruction::CreateParameter(4, full_shape, "to_update"));
  builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
      full_shape, to_update, slice, update_indices));

  module->AddEntryComputation(builder.Build());

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kDynamicSlice, HloOpcode::kDynamicUpdateSlice,
      HloOpcode::kParameter,    HloOpcode::kParameter,
      HloOpcode::kParameter,    HloOpcode::kParameter,
      HloOpcode::kParameter,    HloOpcode::kParameter,
      HloOpcode::kParameter,    HloOpcode::kParameter};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes);
}
569 
// A dynamic-update-slice whose update is a dynamic-slice indexed by a value
// that is itself computed through a dynamic-slice and a reshape. Expects one
// fusion holding both dynamic-slices, the dynamic-update-slice, the reshape,
// the constant and the five parameters.
TEST_F(OpcodeFusionTest, MessOfFusibleNodes) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  const Shape full_shape = ShapeUtil::MakeShape(F32, {4, 100, 10, 100, 50});

  auto* loop_idx = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(S32, {}), "param0"));
  auto* param1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(S32, {}), "param1"));

  // idx_choice = reshape(dynamic-slice(param2, loop_idx)) as an S32 scalar.
  auto* choices = builder.AddInstruction(HloInstruction::CreateParameter(
      2, ShapeUtil::MakeShape(S32, {4}), "param2"));
  auto* chosen = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
      ShapeUtil::MakeShape(S32, {1}), choices, {loop_idx},
      /*slice_sizes=*/{1}));
  auto* idx_choice = builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(S32, {}), chosen));
  auto* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0(0)));

  auto* slice_source = builder.AddInstruction(HloInstruction::CreateParameter(
      3, ShapeUtil::MakeShape(F32, {100, 100, 10, 100, 50}), "param3"));
  auto* slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
      ShapeUtil::MakeShape(F32, {1, 100, 10, 100, 50}), slice_source,
      {idx_choice, zero, zero, zero, zero},
      /*slice_sizes=*/{1, 100, 10, 100, 50}));

  auto* update_target = builder.AddInstruction(
      HloInstruction::CreateParameter(4, full_shape, "param4"));
  builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
      full_shape, update_target, slice,
      {loop_idx, param1, param1, param1, param1}));

  module->AddEntryComputation(builder.Build());

  const std::multiset<HloOpcode> expected_opcodes = {
      HloOpcode::kDynamicSlice,       HloOpcode::kDynamicSlice,
      HloOpcode::kDynamicUpdateSlice, HloOpcode::kReshape,
      HloOpcode::kConstant,           HloOpcode::kParameter,
      HloOpcode::kParameter,          HloOpcode::kParameter,
      HloOpcode::kParameter,          HloOpcode::kParameter};
  RunFusionAndCheckOpcodesWereFused(module.get(), expected_opcodes);
}
613 
CreateComputationForDotAddOutputFusionTest(const std::string & test_name,HloModule * module,int m,int k,int n,bool add_extra_use_for_dot)614 void CreateComputationForDotAddOutputFusionTest(const std::string& test_name,
615                                                 HloModule* module, int m, int k,
616                                                 int n,
617                                                 bool add_extra_use_for_dot) {
618   HloComputation::Builder builder(test_name);
619 
620   Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
621   Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
622   Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});
623   if (m == 1) {
624     dot_lhs_shape = ShapeUtil::MakeShape(F32, {k});
625     dot_shape = ShapeUtil::MakeShape(F32, {n});
626   } else if (n == 1) {
627     dot_rhs_shape = ShapeUtil::MakeShape(F32, {k});
628     dot_shape = ShapeUtil::MakeShape(F32, {m});
629   }
630 
631   auto* dot_lhs = builder.AddInstruction(
632       HloInstruction::CreateParameter(0, dot_lhs_shape, "param0"));
633   auto* dot_rhs = builder.AddInstruction(
634       HloInstruction::CreateParameter(1, dot_rhs_shape, "param1"));
635   auto* addend = builder.AddInstruction(
636       HloInstruction::CreateParameter(2, dot_shape, "param2"));
637 
638   auto* dot =
639       builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
640   builder.AddInstruction(
641       HloInstruction::CreateBinary(dot_shape, HloOpcode::kAdd, dot, addend));
642 
643   if (add_extra_use_for_dot) {
644     auto* token = builder.AddInstruction(HloInstruction::CreateToken());
645     builder.AddInstruction(
646         HloInstruction::CreateOutfeed(dot_shape, dot, token, "no_config"));
647   }
648 
649   module->AddEntryComputation(builder.Build());
650 }
651 
// Vector-matrix case (m == 1): expects the dot, the add and the three
// parameters to end up in a single output fusion.
TEST_F(OpcodeFusionTest, DotAddOutputFusion_1x50x19) {
  auto module = CreateNewVerifiedModule();
  auto* const raw_module = module.get();
  CreateComputationForDotAddOutputFusionTest(TestName(), raw_module, /*m=*/1,
                                             /*k=*/50, /*n=*/19,
                                             /*add_extra_use_for_dot=*/false);

  const HloInstruction::FusionKind expected_kind =
      HloInstruction::FusionKind::kOutput;
  RunFusionAndCheckOpcodesWereFused(
      raw_module,
      {HloOpcode::kDot, HloOpcode::kAdd, HloOpcode::kParameter,
       HloOpcode::kParameter, HloOpcode::kParameter},
      expected_kind);
}
664 
// Matrix-vector case (n == 1): expects the dot, the add and the three
// parameters to end up in a single output fusion.
TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1) {
  auto module = CreateNewVerifiedModule();
  auto* const raw_module = module.get();
  CreateComputationForDotAddOutputFusionTest(TestName(), raw_module, /*m=*/19,
                                             /*k=*/50, /*n=*/1,
                                             /*add_extra_use_for_dot=*/false);

  const HloInstruction::FusionKind expected_kind =
      HloInstruction::FusionKind::kOutput;
  RunFusionAndCheckOpcodesWereFused(
      raw_module,
      {HloOpcode::kDot, HloOpcode::kAdd, HloOpcode::kParameter,
       HloOpcode::kParameter, HloOpcode::kParameter},
      expected_kind);
}
677 
// Matrix-matrix case (neither m nor n is 1): the pass is expected to report
// no change and the root must not be a fusion.
TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x19) {
  auto module = CreateNewVerifiedModule();
  CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19,
                                             /*k=*/50, /*n=*/19,
                                             /*add_extra_use_for_dot=*/false);

  CpuInstructionFusion fusion_pass;
  TF_ASSERT_OK_AND_ASSIGN(const bool changed, fusion_pass.Run(module.get()));
  EXPECT_FALSE(changed);

  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, Not(op::Fusion()));
}
690 
// Matrix-vector case where the dot has an extra user (the outfeed added by
// the helper): the pass is expected to report no change and the root must not
// be a fusion.
TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) {
  auto module = CreateNewVerifiedModule();
  CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19,
                                             /*k=*/50, /*n=*/1,
                                             /*add_extra_use_for_dot=*/true);

  CpuInstructionFusion fusion_pass;
  TF_ASSERT_OK_AND_ASSIGN(const bool changed, fusion_pass.Run(module.get()));
  EXPECT_FALSE(changed);

  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, Not(op::Fusion()));
}
703 
// An add whose two operands are the same dot: the pass is expected to report
// no change and the root must not be a fusion.
TEST_F(InstructionFusionTest,
       DotOperationFusion_DontOutputFuseDuplicateOperands) {
  const absl::string_view hlo_text = R"(
HloModule module

ENTRY main {
  a = f32[50,60]{1,0} parameter(0)
  b = f32[60,1]{1,0} parameter(1)
  c = f32[50,1]{1,0} dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT d = f32[50,1]{1,0} add(c, c)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  CpuInstructionFusion fusion_pass;
  TF_ASSERT_OK_AND_ASSIGN(const bool changed, fusion_pass.Run(module.get()));
  EXPECT_FALSE(changed);

  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, Not(op::Fusion()));
}
725 
726 struct GatherLoopFusionTestSpec {
727   std::string test_name;
728   std::string hlo_computation_text;
729 
Namexla::cpu::__anon8130c97e0111::GatherLoopFusionTestSpec730   static std::string Name(
731       const ::testing::TestParamInfo<GatherLoopFusionTestSpec>& info) {
732     return info.param.test_name;
733   }
734 };
735 
736 class GatherLoopFusionTest
737     : public OpcodeFusionTest,
738       public ::testing::WithParamInterface<GatherLoopFusionTestSpec> {};
739 
// Parses the spec's HLO (prefixed with an "HloModule <test_name>" header) and
// checks that fusion produces exactly the expected set of fused opcodes.
TEST_P(GatherLoopFusionTest, GatherLoopFusion) {
  const std::string& name = GetParam().test_name;
  const std::string& body = GetParam().hlo_computation_text;

  const std::string module_text =
      absl::StrCat("HloModule ", name, "\n\n", body);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(module_text));

  auto* const raw_module = module.get();
  RunFusionAndCheckOpcodesWereFused(
      raw_module,
      {HloOpcode::kGather, HloOpcode::kAdd, HloOpcode::kBroadcast,
       HloOpcode::kConstant, HloOpcode::kParameter, HloOpcode::kParameter});
}
752 
GetGatherLoopFusionTestSpecs()753 std::vector<GatherLoopFusionTestSpec> GetGatherLoopFusionTestSpecs() {
754   std::vector<GatherLoopFusionTestSpec> result;
755 
756   result.push_back({"FusedTensorFlowGatherV2", R"(
757 ENTRY main {
758   operand = s32[3,3] parameter(0)
759   indices = s32[2] parameter(1)
760   gather = s32[3,2] gather(operand, indices),
761       offset_dims={0},
762       collapsed_slice_dims={1},
763       start_index_map={1},
764       index_vector_dim=1,
765       slice_sizes={3, 1}
766   one = s32[] constant(1)
767   one_broadcasted = s32[3,2] broadcast(one), dimensions={}
768   ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted)
769 }
770 )"});
771 
772   result.push_back({"FusedTensorFlowGatherMultipleBatchDims", R"(
773 ENTRY main {
774   operand = s32[3,3] parameter(0)
775   indices = s32[2,2] parameter(1)
776   gather = s32[2,3,2] gather(operand, indices),
777       offset_dims={1},
778       collapsed_slice_dims={1},
779       start_index_map={1},
780       index_vector_dim=2,
781       slice_sizes={3, 1}
782   one = s32[] constant(1)
783   one_broadcasted = s32[2,3,2] broadcast(one), dimensions={}
784   ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted)
785 }
786 )"});
787 
788   result.push_back({"FusedTensorFlowGatherNdMultipleBatchDims", R"(
789 ENTRY main {
790   operand = s32[3,3] parameter(0)
791   indices = s32[2,2,2] parameter(1)
792   gather = s32[2,2] gather(operand, indices),
793       offset_dims={},
794       collapsed_slice_dims={0,1},
795       start_index_map={0,1},
796       index_vector_dim=2,
797       slice_sizes={1, 1}
798   one = s32[] constant(1)
799   one_broadcasted = s32[2,2] broadcast(one), dimensions={}
800   ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
801 }
802 )"});
803 
804   result.push_back({"FusedTensorFlowGatherNd_0", R"(
805 ENTRY main {
806   operand = s32[3,3,2] parameter(0)
807   indices = s32[2,2] parameter(1)
808   gather = s32[2,2] gather(operand, indices),
809       offset_dims={1},
810       collapsed_slice_dims={0,1},
811       start_index_map={0,1},
812       index_vector_dim=1,
813       slice_sizes={1,1,2}
814   one = s32[] constant(1)
815   one_broadcasted = s32[2,2] broadcast(one), dimensions={}
816   ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
817 }
818 )"});
819 
820   result.push_back({"FusedTensorFlowGatherNd_1", R"(
821 ENTRY main {
822   operand = s32[3,3,2] parameter(0)
823   indices = s32[2,2] parameter(1)
824   gather = s32[2,2] gather(operand, indices),
825       offset_dims={1},
826       collapsed_slice_dims={0,1},
827       start_index_map={0,1},
828       index_vector_dim=0,
829       slice_sizes={1,1,2}
830   one = s32[] constant(1)
831   one_broadcasted = s32[2,2] broadcast(one), dimensions={}
832   ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
833 }
834 )"});
835 
836   result.push_back({"FusedDynamicSlice", R"(
837 ENTRY main {
838   operand = s32[3,3] parameter(0)
839   indices = s32[2] parameter(1)
840   gather = s32[1,1] gather(operand, indices),
841       offset_dims={0,1},
842       collapsed_slice_dims={},
843       start_index_map={0,1},
844       index_vector_dim=0,
845       slice_sizes={1,1}
846   one = s32[] constant(1)
847   one_broadcasted = s32[1,1] broadcast(one), dimensions={}
848   ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted)
849 }
850 )"});
851 
852   result.push_back({"FusedBatchDynamicSlice", R"(
853 ENTRY main {
854   operand = s32[3,3] parameter(0)
855   indices = s32[2,2] parameter(1)
856   gather = s32[2,1,1] gather(operand, indices),
857       offset_dims={1,2},
858       collapsed_slice_dims={},
859       start_index_map={0,1},
860       index_vector_dim=0,
861       slice_sizes={1,1}
862   one = s32[] constant(1)
863   one_broadcasted = s32[2,1,1] broadcast(one), dimensions={}
864   ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted)
865 }
866 )"});
867 
868   return result;
869 }
870 
871 INSTANTIATE_TEST_SUITE_P(GatherLoopFusionTestInstantiation,
872                          GatherLoopFusionTest,
873                          ::testing::ValuesIn(GetGatherLoopFusionTestSpecs()),
874                          GatherLoopFusionTestSpec::Name);
875 
// An elementwise add feeding a reduce over dimension 0 of a {1,0}-layout
// operand: the pass is expected to report no change and the root must not be
// a fusion.
TEST_F(InstructionFusionTest, NoFuseReduceMajor) {
  const absl::string_view hlo_text = R"(
HloModule module

add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY main {
  a = f32[50,60]{1,0} parameter(0)
  b = f32[50,60]{1,0} parameter(1)
  c = f32[50,60]{1,0} add(a, b)
  init = f32[] constant(0)
  ROOT r = f32[60]{0} reduce(c, init), dimensions={0}, to_apply=add
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  CpuInstructionFusion fusion_pass;
  TF_ASSERT_OK_AND_ASSIGN(const bool changed, fusion_pass.Run(module.get()));
  EXPECT_FALSE(changed);

  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, Not(op::Fusion()));
}
903 
// An elementwise add feeding a reduce over both dimensions ({0,1}): the pass
// is expected to report a change and leave a fusion at the root.
TEST_F(InstructionFusionTest, FuseReduceMinor) {
  const absl::string_view hlo_text = R"(
HloModule module

add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY main {
  a = f32[50,60]{1,0} parameter(0)
  b = f32[50,60]{1,0} parameter(1)
  c = f32[50,60]{1,0} add(a, b)
  init = f32[] constant(0)
  ROOT r = f32[] reduce(c, init), dimensions={0,1}, to_apply=add
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  CpuInstructionFusion fusion_pass;
  TF_ASSERT_OK_AND_ASSIGN(const bool changed, fusion_pass.Run(module.get()));
  EXPECT_TRUE(changed);

  auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Fusion());
}
930 }  // namespace
931 }  // namespace cpu
932 }  // namespace xla
933