1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <utility>
18 #include <vector>
19
20 #include "tensorflow/compiler/xla/literal.h"
21 #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_module.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
28 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/platform/test.h"
31
32 namespace xla {
33 namespace cpu {
34 namespace {
35
36 class CpuFusionTest : public HloTestBase {
37 protected:
CpuFusionTest()38 CpuFusionTest() {}
39
40 ErrorSpec error_spec_{0.0001, 1e-5};
41
42 private:
GetDebugOptionsForTest()43 DebugOptions GetDebugOptionsForTest() override {
44 DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
45 debug_options.add_xla_disable_hlo_passes("layout-assignment");
46 return debug_options;
47 }
48 };
49
// Builds negate(add(lhs, rhs)) over two constant vectors, runs
// CpuInstructionFusion, and checks that the root became a single fusion whose
// result is numerically correct.
TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
  auto builder = HloComputation::Builder(TestName());
  auto lhs_literal = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
  auto rhs_literal = LiteralUtil::CreateR1<float>({-2.0, -42.0, 2.0});
  // Capture the shape before the literal is moved into its constant.
  Shape vshape = lhs_literal.shape();

  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(lhs_literal)));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(rhs_literal)));

  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(vshape, HloOpcode::kAdd, lhs, rhs));
  // The negate is added last; it is the instruction the checks below expect to
  // find at the root of the fused expression.
  builder.AddInstruction(
      HloInstruction::CreateUnary(vshape, HloOpcode::kNegate, sum));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  CpuInstructionFusion fusion;
  EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());

  // After the pass the entry root must be a fusion instruction wrapping the
  // negate.
  const HloInstruction* fusion_instruction =
      module->entry_computation()->root_instruction();
  EXPECT_EQ(HloOpcode::kFusion, fusion_instruction->opcode());
  EXPECT_EQ(HloOpcode::kNegate,
            fusion_instruction->fused_expression_root()->opcode());
  // Four instructions are expected inside the fusion: the two inputs, the add,
  // and the negate.
  EXPECT_EQ(4, fusion_instruction->fused_instruction_count());

  // Compile and run a clone of the fused module.
  auto result = ExecuteAndTransfer(module->Clone(), {});

  // -(1 + -2) = 1, -(2 + -42) = 40, -(3 + 2) = -5.
  LiteralTestUtil::ExpectR1Near<float>({1.0, 40.0, -5.0}, result, error_spec_);
}
89
// Builds the chain 2 * floor(exp(ceil(negate(input)))) over a constant vector,
// runs CpuInstructionFusion, and checks that the whole chain collapses into
// one fusion at the root with the expected numerical result.
TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
  auto builder = HloComputation::Builder(TestName());
  auto source_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
  // Capture the shape before the literal is moved into its constant.
  Shape vshape = source_literal.shape();

  HloInstruction* source = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(source_literal)));
  HloInstruction* negated = builder.AddInstruction(
      HloInstruction::CreateUnary(vshape, HloOpcode::kNegate, source));
  HloInstruction* ceiled = builder.AddInstruction(
      HloInstruction::CreateUnary(vshape, HloOpcode::kCeil, negated));
  HloInstruction* exponent = builder.AddInstruction(
      HloInstruction::CreateUnary(vshape, HloOpcode::kExp, ceiled));
  HloInstruction* floored = builder.AddInstruction(
      HloInstruction::CreateUnary(vshape, HloOpcode::kFloor, exponent));

  // Scalar 2.0 broadcast to the vector shape. The scalar constant is added to
  // the builder before the broadcast that consumes it.
  HloInstruction* scalar_two = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* two = builder.AddInstruction(
      HloInstruction::CreateBroadcast(vshape, scalar_two, {}));

  // The multiply is added last; it is the instruction the checks below expect
  // to find at the root of the fused expression.
  builder.AddInstruction(
      HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, two, floored));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  CpuInstructionFusion fusion;
  EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());

  // After the pass the entry root must be a fusion instruction wrapping the
  // multiply.
  const HloInstruction* fusion_instruction =
      module->entry_computation()->root_instruction();
  EXPECT_EQ(HloOpcode::kFusion, fusion_instruction->opcode());
  EXPECT_EQ(HloOpcode::kMultiply,
            fusion_instruction->fused_expression_root()->opcode());
  // Eight instructions are expected inside the fusion: the two inputs plus the
  // six operations (negate, ceil, exp, floor, broadcast, multiply).
  EXPECT_EQ(8, fusion_instruction->fused_instruction_count());

  // Compile and run a clone of the fused module.
  auto result = ExecuteAndTransfer(module->Clone(), {});

  // 2 * floor(exp(ceil(1.5))) = 14, and 2 * floor(exp(3)) = 40 for the other
  // two elements.
  LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0}, result, error_spec_);
}
136
// Builds a chain of elementwise ops that is interrupted by a reduce, runs
// CpuInstructionFusion, and checks that two separate fusions are produced: one
// at the root (after the reduce) and one feeding the reduce.
TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
  // Test a chain of fusible ops with a non-fusible op (a reduce) thrown in the
  // middle.
  auto module = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
  // Capture the shape before the literal is moved into its constant.
  Shape vshape = input_literal.shape();

  auto input = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(input_literal)));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vshape, HloOpcode::kNegate, input));
  auto ceil = builder.AddInstruction(
      HloInstruction::CreateUnary(vshape, HloOpcode::kCeil, negate));

  // Concatenating the 3-element vector with itself yields 6 elements.
  auto cshape = ShapeUtil::MakeShape(F32, {6});
  auto concatenate = builder.AddInstruction(
      HloInstruction::CreateConcatenate(cshape, {ceil, ceil}, /*dimension=*/0));

  // Build an x+y computation to use in a reduce.
  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  auto embedded_builder = HloComputation::Builder("f32+f32");
  embedded_builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32, HloOpcode::kAdd,
      embedded_builder.AddInstruction(
          HloInstruction::CreateParameter(0, r0f32, "x")),
      embedded_builder.AddInstruction(
          HloInstruction::CreateParameter(1, r0f32, "y"))));
  auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build());

  // This is a nop reduction: the input is reshaped to {1, 6} and the size-1
  // dimension is summed away with an init value of 0.
  auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
      cshape,
      builder.AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(F32, {1, 6}), concatenate)),
      /*init_value=*/
      builder.AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0))),
      /*dimensions_to_reduce=*/{0}, add_f32));

  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(cshape, HloOpcode::kExp, reduce));
  auto floor = builder.AddInstruction(
      HloInstruction::CreateUnary(cshape, HloOpcode::kFloor, exp));
  auto two = builder.AddInstruction(HloInstruction::CreateBroadcast(
      cshape,
      builder.AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0))),
      {}));
  builder.AddInstruction(
      HloInstruction::CreateBinary(cshape, HloOpcode::kMultiply, two, floor));

  module->AddEntryComputation(builder.Build());

  CpuInstructionFusion fusion;
  EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());

  // The computation root instruction was fused. Verify the fusion instruction
  // is now the root.
  auto computation = module->entry_computation();

  auto fusion_instruction1 = computation->root_instruction();
  EXPECT_EQ(HloOpcode::kFusion, fusion_instruction1->opcode());
  EXPECT_EQ(HloOpcode::kMultiply,
            fusion_instruction1->fused_expression_root()->opcode());
  // There should be 6 fused instructions in the root fusion instruction: its 2
  // inputs plus the broadcast, exp, floor, and multiply.
  EXPECT_EQ(6, fusion_instruction1->fused_instruction_count())
      << fusion_instruction1->fused_instructions_computation()->ToString();

  // The reduce's data operand is expected to be the second fusion. Check the
  // opcode of *that* instruction (not the root fusion again).
  auto fusion_instruction2 = reduce->operand(0);
  EXPECT_EQ(HloOpcode::kFusion, fusion_instruction2->opcode());
  EXPECT_EQ(HloOpcode::kReshape,
            fusion_instruction2->fused_expression_root()->opcode());
  // There should be 5 fused instructions in the second fusion instruction: its
  // 1 input plus negate, ceil, concat, and reshape.
  EXPECT_EQ(5, fusion_instruction2->fused_instruction_count())
      << fusion_instruction2->fused_instructions_computation()->ToString();

  // Compile and execute the computation.
  auto result = ExecuteAndTransfer(module->Clone(), {});

  // Check the output correctness: {14, 40, 40} repeated twice because of the
  // self-concatenation.
  LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0, 14.0, 40.0, 40.0},
                                       result, error_spec_);
}
223
// Checks that operands are fused in an order that does not duplicate shared
// producers, regardless of the order in which a consumer lists its operands.
TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
  // Test that the operands of an instruction to be fused are considered in the
  // proper order to avoid duplication. Test input:
  //
  //   constant = {...}
  //   negate   = neg(constant)
  //   ceil     = ceil(negate)
  //   mul1     = multiply(negate, ceil)
  //   mul2     = multiply(ceil, negate)
  //
  // In this example, the operands of both mul1 and mul2 should be fused in the
  // order {ceil, negate} even though they have different orders in their
  // operand vectors. Test for this problem by counting the number of nodes in
  // each fusion instruction to ensure that negate is not duplicated.
  auto builder = HloComputation::Builder(TestName());
  auto input_literal = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
  // Capture the shape before the literal is moved into its constant.
  Shape vshape = input_literal.shape();

  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(input_literal)));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vshape, HloOpcode::kNegate, constant));
  auto ceil = builder.AddInstruction(
      HloInstruction::CreateUnary(vshape, HloOpcode::kCeil, negate));

  // Same two operands, opposite operand order.
  auto mul1 = builder.AddInstruction(
      HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, negate, ceil));
  auto mul2 = builder.AddInstruction(
      HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, ceil, negate));

  // Tie together the two multiplies with a tuple to create a single root.
  auto result =
      builder.AddInstruction(HloInstruction::CreateTuple({mul1, mul2}));

  // Create computation and module.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Run fusion.
  CpuInstructionFusion fusion;
  EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());

  // Both tuple operands should have been replaced by fusion instructions.
  auto fusion1 = result->operand(0);
  auto fusion2 = result->operand(1);
  EXPECT_EQ(HloOpcode::kFusion, fusion1->opcode());
  EXPECT_EQ(HloOpcode::kFusion, fusion2->opcode());

  // Each fusion instruction should have 4 fused instructions inside: multiply,
  // ceil, negate, and the fused constant. A fifth instruction would mean that
  // negate was duplicated.
  EXPECT_EQ(4, fusion1->fused_instruction_count());
  EXPECT_EQ(4, fusion2->fused_instruction_count());

  // The fusion has no parameters, everything is fused including constants.
  EXPECT_EQ(0, fusion1->operand_count());
  EXPECT_EQ(0, fusion2->operand_count());
}
280
// Checks that an expensive op (exp) is fused into its consumer only when doing
// so does not duplicate it.
TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) {
  // Verify that expensive operations will not be fused if the fusion results in
  // duplication. Test code:
  //
  //   constant = 42.0
  //   exp1 = exp(constant)
  //   negate1 = negate(exp1)
  //   exp2 = exp(constant)
  //   negate2 = negate(exp2)
  //   tuple = tuple(negate1, negate2, exp2)
  //
  // exp1 should be fused down into negate1, but exp2 will not be fused into
  // negate2 because this will result in duplication of the expensive exp
  // computation. The duplication is caused by the other use of exp2 in the
  // tuple.
  auto builder = HloComputation::Builder(TestName());
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  Shape shape = constant->shape();

  auto exp1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kExp, constant));
  auto negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, exp1));

  auto exp2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kExp, constant));
  auto negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, exp2));

  // exp2 is listed directly in the tuple, giving it a second user besides
  // negate2.
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({negate1, negate2, exp2}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  CpuInstructionFusion fusion;
  EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());

  // The only fusion instruction should be operand 0 of the tuple (formerly
  // negate1).
  EXPECT_EQ(HloOpcode::kFusion, tuple->operand(0)->opcode());
  EXPECT_EQ(HloOpcode::kNegate, tuple->operand(1)->opcode());
  EXPECT_EQ(HloOpcode::kExp, tuple->operand(2)->opcode());

  auto fusion_inst = tuple->operand(0);
  // There should be three fused instructions: negate1, exp1, and the fused
  // constant. With the constant fused, the fusion takes no operands.
  EXPECT_EQ(3, fusion_inst->fused_instruction_count());
  EXPECT_EQ(0, fusion_inst->operand_count());
}
334
335 } // namespace
336 } // namespace cpu
337 } // namespace xla
338