/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h"

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"

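// Tests for the GpuHorizontalInputFusion pass, which merges sibling "input"
// (reduction) fusions into multi-output fusions so that independent
// reductions can share a single kernel launch.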
namespace xla {
namespace gpu {
namespace {

namespace op = xla::testing::opcode_matchers;

class HorizontalInputFusionTest : public GpuCodegenTest {};

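// BasicTest: two sibling "input" fusions, each reducing an f16[1024] operand,
// should be merged into a single multi-output fusion whose root tuples the
// two reduce results.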
TEST_F(HorizontalInputFusionTest, BasicTest) {
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule BasicTest

 %add_f16 {
   %x = f16[] parameter(0)
   %y = f16[] parameter(1)
   ROOT %add = f16[] add(%x, %y)
 }

 fused_computation.1 {
   arg.1 = f16[1024]{0} parameter(0)
   constant0 = f16[] constant(0)
   ROOT reduce1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
 }

 fused_computation.2 {
   arg.1 = f16[1024]{0} parameter(0)
   constant0 = f16[] constant(0)
   ROOT reduce1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
 }

 ENTRY entry_computation {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   fusion.1 = f16[] fusion(arg.1), kind=kInput, calls=fused_computation.1
   fusion.2 = f16[] fusion(arg.2), kind=kInput, calls=fused_computation.2
   ROOT tuple.1 = (f16[], f16[]) tuple(fusion.1, fusion.2)
 }
)")
                    .ValueOrDie();

  EXPECT_TRUE(GpuHorizontalInputFusion().Run(module.get()).ValueOrDie());

  const HloInstruction* entry_root =
      module->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, op::Tuple((op::GetTupleElement(op::Fusion())),
                                    (op::GetTupleElement(op::Fusion()))));

  const HloInstruction* fusion = entry_root->operand(0)->operand(0);
  ASSERT_TRUE(fusion->IsMultiOutputFusion());
  EXPECT_THAT(fusion->fused_expression_root(),
              op::Tuple(op::Reduce(), op::Reduce()));
}

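// ManyInputFusions: builds 130 independent multiply+reduce chains and runs
// the full GPU compiler, checking in the generated IR that many of the
// resulting input fusions were grouped into one horizontally fused kernel
// and that the outputs still match a reference run.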
TEST_F(HorizontalInputFusionTest, ManyInputFusions) {
  auto module = CreateNewVerifiedModule();

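  // Build the scalar f32 add computation that the reduces below use as
  // their combiner.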
  HloComputation* reduce_computation;
  {
    auto embedded_builder = HloComputation::Builder("add");
    auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
        0, ShapeUtil::MakeShape(F32, {}), "lhs"));
    auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
        1, ShapeUtil::MakeShape(F32, {}), "rhs"));
    embedded_builder.AddInstruction(
        HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
    reduce_computation =
        module->AddEmbeddedComputation(embedded_builder.Build());
  }

  HloComputation::Builder builder(TestName());
  std::vector<HloInstruction*> var_outs;
  auto input_shape = ShapeUtil::MakeShape(F32, {1024, 1024});
  auto output_shape = ShapeUtil::MakeShape(F32, {1024});
  for (int64_t i = 0; i < 130; ++i) {
    // Each iteration builds the unfused equivalent of:
    // %fused_computation.3 (param_0: f32[1024,1024], param_1: f32[])
    //     -> f32[1024] {
    //   %param_0 = f32[1024,1024]{1,0} parameter(0)
    //   %param_1 = f32[] parameter(1)
    //   %broadcast = f32[1024,1024]{1,0} broadcast(f32[] %param_1),
    //       dimensions={}
    //   %multiply = f32[1024,1024]{1,0} multiply(
    //       f32[1024,1024]{1,0} %param_0, f32[1024,1024]{1,0} %broadcast)
    //   %constant0 = f32[] constant(0)
    //   ROOT %reduce = f32[1024]{0} reduce(
    //       f32[1024,1024]{1,0} %multiply, f32[] %constant0),
    //       dimensions={1}, to_apply=%add
    // }
    HloInstruction* param_var_in = builder.AddInstruction(
        HloInstruction::CreateParameter(i * 2 + 0, input_shape, "var.in"));
    HloInstruction* param_alpha =
        builder.AddInstruction(HloInstruction::CreateParameter(
            i * 2 + 1, ShapeUtil::MakeShape(F32, {}), "alpha"));
    auto alpha_broadcasted = builder.AddInstruction(
        HloInstruction::CreateBroadcast(input_shape, param_alpha, {}));
    auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
        input_shape, HloOpcode::kMultiply, param_var_in, alpha_broadcasted));
    HloInstruction* const0 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
    auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
        output_shape, mul, const0, {1}, reduce_computation));
    var_outs.push_back(reduce);
  }
  builder.AddInstruction(HloInstruction::CreateTuple(var_outs));
  module->AddEntryComputation(builder.Build());

  // Verify that horizontal fusion kicked in: multiple `reduce` instructions
  // should end up fused into the same fusion. The `6` in `reduce-group-6` is
  // an arbitrary pick; we do not know exactly how large the created fusion
  // will be, due to the `FusionFitsInBudget` constraint.
  CompileAndVerifyIr(module->Clone(), R"(CHECK: reduce-group-6)",
                     /*match_optimized_ir=*/false);

  // Also test with the entire GPU optimization pipeline.
  EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5}));
}

TEST_F(HorizontalInputFusionTest, MultiOutputFusionTest) {
  // This tests the pattern below. One known issue is that gtes (to fusions)
  // can be removed after their producer fusions are merged. In the case
  // below, gte2 and gte6 will be gone if Fusion2 is fused into Fusion1.
  //
  // Fusion1   Fusion2
  //  |   |    |     |
  //  |  gte1 gte2   |
  //  |   |    |     |
  //  |   Fusion3    |
  //  |    |   |     |
  // gte3 gte4 gte5 gte6
  //  \  |     |    /
  //  =====ROOT=====
  //
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule MultiOutputFusionTest

 %add_f16 {
   %x = f16[] parameter(0)
   %y = f16[] parameter(1)
   ROOT %add = f16[] add(%x, %y)
 }

 fused_computation.1 {
   arg.1 = f16[1024]{0} parameter(0)
   constant0 = f16[] constant(0)
   reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
   add.0 = f16[1024] add(arg.1, arg.1)
   ROOT tuple.1 = (f16[], f16[1024]) tuple(reduce.1, add.0)
 }

 fused_computation.2 {
   arg.1 = f16[1024]{0} parameter(0)
   constant0 = f16[] constant(0)
   reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
   add.0 = f16[1024] add(arg.1, arg.1)
   ROOT tuple.1 = (f16[], f16[1024]) tuple(reduce.1, add.0)
 }

 fused_computation.3 {
   arg.0 = f16[1024]{0} parameter(0)
   arg.1 = f16[1024]{0} parameter(1)
   add.0 = f16[1024] add(arg.0, arg.1)
   mul.0 = f16[1024] multiply(arg.0, arg.1)
   ROOT tuple.1 = (f16[1024], f16[1024]) tuple(add.0, mul.0)
 }

 ENTRY entry_computation {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   fusion.1 = (f16[],f16[1024]) fusion(arg.1), kind=kInput, calls=fused_computation.1
   fusion.2 = (f16[],f16[1024]) fusion(arg.2), kind=kInput, calls=fused_computation.2
   gte.3 = f16[] get-tuple-element(fusion.1), index=0
   gte.1 = f16[1024]{0} get-tuple-element(fusion.1), index=1
   gte.2 = f16[1024]{0} get-tuple-element(fusion.2), index=1
   gte.6 = f16[] get-tuple-element(fusion.2), index=0
   fusion.3 = (f16[1024],f16[1024]) fusion(gte.1, gte.2),
       kind=kLoop, calls=fused_computation.3
   gte.4 = f16[1024] get-tuple-element(fusion.3), index=0
   gte.5 = f16[1024]{0} get-tuple-element(fusion.3), index=1
   ROOT tuple.1 = (f16[], f16[1024], f16[1024]{0}, f16[])
       tuple(gte.3, gte.4, gte.5, gte.6)
 }
)")
                    .ValueOrDie();

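  // Running the pass must succeed and report a change, even with the
  // intervening consumer fusion (Fusion3) between the two producers.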
  EXPECT_TRUE(GpuHorizontalInputFusion().Run(module.get()).ValueOrDie());
}

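// NonfusionInstrs: reduce instructions that are not yet wrapped in fusions
// should also be picked up and merged into a single multi-output fusion.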
TEST_F(HorizontalInputFusionTest, NonfusionInstrs) {
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule NonfusionInstrs

 %add_f16 {
   %x = f16[] parameter(0)
   %y = f16[] parameter(1)
   ROOT %add = f16[] add(%x, %y)
 }

 ENTRY entry_computation {
   arg.0 = f16[1024]{0} parameter(0)
   arg.1 = f16[1024]{0} parameter(1)
   constant0 = f16[] constant(0)
   reduce.0 = f16[] reduce(arg.0, constant0), dimensions={0}, to_apply=%add_f16
   reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
   ROOT tuple.0 = (f16[], f16[]) tuple(reduce.0, reduce.1)
 }
)")
                    .ValueOrDie();

  EXPECT_TRUE(GpuHorizontalInputFusion().Run(module.get()).ValueOrDie());

  const HloInstruction* entry_root =
      module->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, op::Tuple((op::GetTupleElement(op::Fusion())),
                                    (op::GetTupleElement(op::Fusion()))));

  const HloInstruction* fusion = entry_root->operand(0)->operand(0);
  ASSERT_TRUE(fusion->IsMultiOutputFusion());
  EXPECT_THAT(fusion->fused_expression_root(),
              op::Tuple(op::Reduce(), op::Reduce()));
}

247 
248 }  // namespace
249 }  // namespace gpu
250 }  // namespace xla
251