• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <utility>
18 #include <vector>
19 
20 #include "tensorflow/compiler/xla/literal.h"
21 #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_module.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
28 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/platform/test.h"
31 
32 namespace xla {
33 namespace cpu {
34 namespace {
35 
36 class CpuFusionTest : public HloTestBase {
37  protected:
CpuFusionTest()38   CpuFusionTest() {}
39 
40   ErrorSpec error_spec_{0.0001, 1e-5};
41 
42  private:
GetDebugOptionsForTest()43   DebugOptions GetDebugOptionsForTest() override {
44     DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
45     debug_options.add_xla_disable_hlo_passes("layout-assignment");
46     return debug_options;
47   }
48 };
49 
// Builds negate(add(lhs, rhs)) over two rank-1 float constants, runs the CPU
// fusion pass, and checks both the fused structure and the executed result.
TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
  HloComputation::Builder hlo_builder(TestName());
  auto lhs_literal = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
  auto rhs_literal = LiteralUtil::CreateR1<float>({-2.0, -42.0, 2.0});
  const Shape vector_shape = lhs_literal.shape();

  auto* lhs = hlo_builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(lhs_literal)));
  auto* rhs = hlo_builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(rhs_literal)));

  auto* sum = hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(vector_shape, HloOpcode::kAdd, lhs, rhs));
  // The negate is added last and therefore becomes the computation root.
  hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape, HloOpcode::kNegate, sum));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(hlo_builder.Build());

  // The pass must report that it changed the module.
  CpuInstructionFusion fusion_pass;
  EXPECT_TRUE(fusion_pass.Run(module.get()).ValueOrDie());

  // After fusion the entry root is expected to be a fusion instruction whose
  // fused expression is rooted at the negate.
  const auto* fused_root = module->entry_computation()->root_instruction();
  EXPECT_EQ(HloOpcode::kFusion, fused_root->opcode());
  EXPECT_EQ(HloOpcode::kNegate,
            fused_root->fused_expression_root()->opcode());
  // Four fused instructions are expected in total: the two inputs, the add,
  // and the negate.
  EXPECT_EQ(4, fused_root->fused_instruction_count());

  // Compile and run a clone of the module with no arguments.
  auto output = ExecuteAndTransfer(module->Clone(), {});

  // -(lhs + rhs), element by element.
  LiteralTestUtil::ExpectR1Near<float>({1.0, 40.0, -5.0}, output, error_spec_);
}
89 
// Builds 2 * floor(exp(ceil(negate(input)))) as a chain of elementwise ops,
// runs the CPU fusion pass, and checks the fused structure and the result.
TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
  HloComputation::Builder hlo_builder(TestName());
  auto source_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
  const Shape vector_shape = source_literal.shape();

  auto* source = hlo_builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(source_literal)));
  auto* negated = hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape, HloOpcode::kNegate, source));
  auto* rounded_up = hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape, HloOpcode::kCeil, negated));
  auto* exponentiated = hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape, HloOpcode::kExp, rounded_up));
  auto* rounded_down = hlo_builder.AddInstruction(HloInstruction::CreateUnary(
      vector_shape, HloOpcode::kFloor, exponentiated));

  // Scalar 2.0 broadcast to the vector shape.
  auto* scalar_two = hlo_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  auto* broadcast_two = hlo_builder.AddInstruction(
      HloInstruction::CreateBroadcast(vector_shape, scalar_two, {}));

  // The multiply is added last and therefore becomes the computation root.
  hlo_builder.AddInstruction(HloInstruction::CreateBinary(
      vector_shape, HloOpcode::kMultiply, broadcast_two, rounded_down));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(hlo_builder.Build());

  // The pass must report that it changed the module.
  CpuInstructionFusion fusion_pass;
  EXPECT_TRUE(fusion_pass.Run(module.get()).ValueOrDie());

  // After fusion the entry root is expected to be a fusion instruction whose
  // fused expression is rooted at the multiply.
  const auto* fused_root = module->entry_computation()->root_instruction();
  EXPECT_EQ(HloOpcode::kFusion, fused_root->opcode());
  EXPECT_EQ(HloOpcode::kMultiply,
            fused_root->fused_expression_root()->opcode());
  // Eight fused instructions are expected in total.
  EXPECT_EQ(8, fused_root->fused_instruction_count());

  // Compile and run a clone of the module with no arguments.
  auto output = ExecuteAndTransfer(module->Clone(), {});

  // Check the output correctness.
  LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0}, output, error_spec_);
}
136 
TEST_F(CpuFusionTest,ElementwiseOpChainWithNonfusibleInstruction)137 TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
138   // Test a chain of fusible ops with a non-fusible op (a reduce) thrown in the
139   // middle.
140   auto module = CreateNewVerifiedModule();
141   auto builder = HloComputation::Builder(TestName());
142   auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
143   Shape vshape = input_literal.shape();
144 
145   auto input = builder.AddInstruction(
146       HloInstruction::CreateConstant(std::move(input_literal)));
147   auto negate = builder.AddInstruction(
148       HloInstruction::CreateUnary(vshape, HloOpcode::kNegate, input));
149   auto ceil = builder.AddInstruction(
150       HloInstruction::CreateUnary(vshape, HloOpcode::kCeil, negate));
151 
152   auto cshape = ShapeUtil::MakeShape(F32, {6});
153   auto concatenate = builder.AddInstruction(
154       HloInstruction::CreateConcatenate(cshape, {ceil, ceil}, /*dimension=*/0));
155 
156   // Build an x+y computation to use in a reduce.
157   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
158   auto embedded_builder = HloComputation::Builder("f32+f32");
159   embedded_builder.AddInstruction(HloInstruction::CreateBinary(
160       r0f32, HloOpcode::kAdd,
161       embedded_builder.AddInstruction(
162           HloInstruction::CreateParameter(0, r0f32, "x")),
163       embedded_builder.AddInstruction(
164           HloInstruction::CreateParameter(1, r0f32, "y"))));
165   auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build());
166 
167   // This is a nop reduction.
168   auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
169       cshape,
170       builder.AddInstruction(HloInstruction::CreateReshape(
171           ShapeUtil::MakeShape(F32, {1, 6}), concatenate)),
172       /*init_value=*/
173       builder.AddInstruction(
174           HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0))),
175       /*dimensions_to_reduce=*/{0}, add_f32));
176 
177   auto exp = builder.AddInstruction(
178       HloInstruction::CreateUnary(cshape, HloOpcode::kExp, reduce));
179   auto floor = builder.AddInstruction(
180       HloInstruction::CreateUnary(cshape, HloOpcode::kFloor, exp));
181   auto two = builder.AddInstruction(HloInstruction::CreateBroadcast(
182       cshape,
183       builder.AddInstruction(
184           HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0))),
185       {}));
186   builder.AddInstruction(
187       HloInstruction::CreateBinary(cshape, HloOpcode::kMultiply, two, floor));
188 
189   module->AddEntryComputation(builder.Build());
190 
191   CpuInstructionFusion fusion;
192   EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
193 
194   // The computation root instruction was fused. Verify the fusion instruction
195   // is now the root.
196   auto computation = module->entry_computation();
197 
198   auto fusion_instruction1 = computation->root_instruction();
199   EXPECT_EQ(HloOpcode::kFusion, fusion_instruction1->opcode());
200   EXPECT_EQ(HloOpcode::kMultiply,
201             fusion_instruction1->fused_expression_root()->opcode());
202   // There should be 6 fused instructions in the root fusion instruction: 2
203   // parameters, multiply, floor, and exp.
204   EXPECT_EQ(6, fusion_instruction1->fused_instruction_count())
205       << fusion_instruction1->fused_instructions_computation()->ToString();
206 
207   auto fusion_instruction2 = reduce->operand(0);
208   EXPECT_EQ(HloOpcode::kFusion, fusion_instruction1->opcode());
209   EXPECT_EQ(HloOpcode::kReshape,
210             fusion_instruction2->fused_expression_root()->opcode());
211   // There should be 5 fused instructions in the second fusion instruction: 1
212   // parameter, negate, ceil, concat, and reshape.
213   EXPECT_EQ(5, fusion_instruction2->fused_instruction_count())
214       << fusion_instruction2->fused_instructions_computation()->ToString();
215 
216   // Compile and execute the computation.
217   auto result = ExecuteAndTransfer(module->Clone(), {});
218 
219   // Check the output correctness.
220   LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0, 14.0, 40.0, 40.0},
221                                        result, error_spec_);
222 }
223 
TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
  // Test that the operands of an instruction to be fused are considered in the
  // proper order to avoid duplication. Test input:
  //
  //   constant = {...}
  //   negate    = neg(constant)
  //   ceil      = ceil(negate)
  //   add1      = add(negate, ceil)
  //   add2      = add(ceil, negate)
  //
  // In this example, the operands of both add1 and add2 should be fused in the
  // order {ceil, negate} even though they have different orders in their
  // operand vectors. Test for this problem by counting the number of nodes in
  // each fusion instruction to ensure that negate is not duplicated.
  auto builder = HloComputation::Builder(TestName());
  auto input_literal = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
  Shape vshape = input_literal.shape();

  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(input_literal)));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vshape, HloOpcode::kNegate, constant));
  auto ceil = builder.AddInstruction(
      HloInstruction::CreateUnary(vshape, HloOpcode::kCeil, negate));

  // Both consumers are adds, as described above; they differ only in the
  // order of their operands.
  auto add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(vshape, HloOpcode::kAdd, negate, ceil));
  auto add2 = builder.AddInstruction(
      HloInstruction::CreateBinary(vshape, HloOpcode::kAdd, ceil, negate));

  // Tie together the two adds with a tuple to create a single root.
  auto result =
      builder.AddInstruction(HloInstruction::CreateTuple({add1, add2}));

  // Create computation and module.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Run fusion.
  CpuInstructionFusion fusion;
  EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());

  // Each tuple operand (formerly add1 and add2) should now be a fusion.
  auto fusion1 = result->operand(0);
  auto fusion2 = result->operand(1);
  EXPECT_EQ(HloOpcode::kFusion, fusion1->opcode());
  EXPECT_EQ(HloOpcode::kFusion, fusion2->opcode());

  // Each fusion instruction should have 4 fused instructions inside: add,
  // ceil, negate, and the fused constant.
  EXPECT_EQ(4, fusion1->fused_instruction_count());
  EXPECT_EQ(4, fusion2->fused_instruction_count());

  // The fusion has no parameters, everything is fused including constants.
  EXPECT_EQ(0, fusion1->operand_count());
  EXPECT_EQ(0, fusion2->operand_count());
}
280 
TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) {
  // Verify that expensive operations will not be fused if the fusion results in
  // duplication. Test code:
  //
  //   constant = 42.0
  //   exp1 = exp(constant)
  //   negate1 = negate(exp1)
  //   exp2 = exp(constant)
  //   negate2 = negate(exp2)
  //   tuple = tuple(negate1, negate2, exp2)
  //
  // exp1 should be fused down into negate1, but exp2 will not be fused into
  // negate2 because this will result in duplication of the expensive exp
  // computation. The duplication is caused by the other use of exp2 in the
  // tuple.
  auto builder = HloComputation::Builder(TestName());
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  Shape shape = constant->shape();

  auto exp1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kExp, constant));
  auto negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, exp1));

  auto exp2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kExp, constant));
  auto negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, exp2));

  // exp2 appears directly in the tuple, giving it a second user besides
  // negate2.
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({negate1, negate2, exp2}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  CpuInstructionFusion fusion;
  EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());

  // The only fusion instruction should be operand 0 of the tuple (formerly
  // negate1).
  EXPECT_EQ(HloOpcode::kFusion, tuple->operand(0)->opcode());
  EXPECT_EQ(HloOpcode::kNegate, tuple->operand(1)->opcode());
  EXPECT_EQ(HloOpcode::kExp, tuple->operand(2)->opcode());

  auto fusion_inst = tuple->operand(0);
  // There should be three fused instructions: negate1, exp1, and the fused
  // constant. With the constant fused, the fusion has no operands.
  EXPECT_EQ(3, fusion_inst->fused_instruction_count());
  EXPECT_EQ(0, fusion_inst->operand_count());
}
334 
335 }  // namespace
336 }  // namespace cpu
337 }  // namespace xla
338