/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h"

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"

namespace xla {
namespace gpu {
namespace {

namespace op = xla::testing::opcode_matchers;

class HorizontalInputFusionTest : public GpuCodegenTest {};

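// Tests the basic case: two sibling input fusions, each a single reduction,
// should be horizontally fused into one multi-output fusion.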
TEST_F(HorizontalInputFusionTest, BasicTest) {
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule BasicTest

 %add_f16 {
   %x = f16[] parameter(0)
   %y = f16[] parameter(1)
   ROOT %add = f16[] add(%x, %y)
 }

 fused_computation.1 {
   arg.1 = f16[1024]{0} parameter(0)
   constant0 = f16[] constant(0)
   ROOT reduce1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
 }

 fused_computation.2 {
   arg.1 = f16[1024]{0} parameter(0)
   constant0 = f16[] constant(0)
   ROOT reduce1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
 }

 ENTRY entry_computation {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   fusion.1 = f16[] fusion(arg.1), kind=kInput, calls=fused_computation.1
   fusion.2 = f16[] fusion(arg.2), kind=kInput, calls=fused_computation.2
   ROOT tuple.1 = (f16[], f16[]) tuple(fusion.1, fusion.2)
 }
)")
                    .ValueOrDie();

  EXPECT_TRUE(GpuHorizontalInputFusion().Run(module.get()).ValueOrDie());

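  // The two fusions should have been merged into a single multi-output
  // fusion; the entry root now reads both results via get-tuple-elements.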
  const HloInstruction* entry_root =
      module->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, op::Tuple((op::GetTupleElement(op::Fusion())),
                                    (op::GetTupleElement(op::Fusion()))));

  const HloInstruction* fusion = entry_root->operand(0)->operand(0);
  ASSERT_TRUE(fusion->IsMultiOutputFusion());
  EXPECT_THAT(fusion->fused_expression_root(),
              op::Tuple(op::Reduce(), op::Reduce()));
}

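// Builds 130 independent multiply+reduce patterns and checks that the pass
// groups many of them into shared fusions, subject to the fusion budget.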
TEST_F(HorizontalInputFusionTest, ManyInputFusions) {
  auto module = CreateNewVerifiedModule();

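  // Build a scalar f32 add computation to serve as the reduction function.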
  HloComputation* reduce_computation;
  {
    auto embedded_builder = HloComputation::Builder("add");
    auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
        0, ShapeUtil::MakeShape(F32, {}), "lhs"));
    auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
        1, ShapeUtil::MakeShape(F32, {}), "rhs"));
    embedded_builder.AddInstruction(
        HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
    reduce_computation =
        module->AddEmbeddedComputation(embedded_builder.Build());
  }

  HloComputation::Builder builder(TestName());
  std::vector<HloInstruction*> var_outs;
  auto input_shape = ShapeUtil::MakeShape(F32, {1024, 1024});
  auto output_shape = ShapeUtil::MakeShape(F32, {1024});
  for (int64_t i = 0; i < 130; ++i) {
    // %fused_computation.3 (param_0: f32[1024,1024], param_1: f32[]) ->
    // f32[1024] {
    //   %param_0 = f32[1024,1024]{1,0} parameter(0)
    //   %param_1 = f32[] parameter(1)
    //   %broadcast = f32[1024,1024]{1,0} broadcast(f32[] %param_1),
    //       dimensions={}
    //   %multiply = f32[1024,1024]{1,0}
    //       multiply(f32[1024,1024]{1,0} %param_0,
    //                f32[1024,1024]{1,0} %broadcast)
    //   %constant0 = f32[] constant(0)
    //   ROOT %reduce = f32[1024]{0}
    //       reduce(f32[1024,1024]{1,0} %multiply, f32[] %constant0),
    //       dimensions={1}, to_apply=%add
    // }
    HloInstruction* param_var_in = builder.AddInstruction(
        HloInstruction::CreateParameter(i * 2 + 0, input_shape, "var.in"));
    HloInstruction* param_alpha =
        builder.AddInstruction(HloInstruction::CreateParameter(
            i * 2 + 1, ShapeUtil::MakeShape(F32, {}), "alpha"));
    auto alpha_broadcasted = builder.AddInstruction(
        HloInstruction::CreateBroadcast(input_shape, param_alpha, {}));
    auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
        input_shape, HloOpcode::kMultiply, param_var_in, alpha_broadcasted));
    HloInstruction* const0 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
    auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
        output_shape, mul, const0, {1}, reduce_computation));
    var_outs.push_back(reduce);
  }
  builder.AddInstruction(HloInstruction::CreateTuple(var_outs));
  module->AddEntryComputation(builder.Build());

  // Verify that horizontal fusion kicked in: check that multiple `reduce`
  // instructions were fused into the same fusion. The index 6 is an arbitrary
  // choice, since we don't know exactly how large each fused group will be
  // due to the `FusionFitsInBudget` constraint.
  CompileAndVerifyIr(module->Clone(), R"(CHECK: reduce-group-6)",
                     /*match_optimized_ir=*/false);

  // Test with the entire GPU optimization pipeline.
  EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5}));
}

TEST_F(HorizontalInputFusionTest, MultiOutputFusionTest) {
  // This tests the pattern below. One known issue is that the gtes (consumers
  // of the fusions) can be removed after their producer fusions are merged;
  // in this case, gte2 and gte6 will be gone if Fusion2 is fused into Fusion1.
  //
  // Fusion1   Fusion2
  //  |   |    |   |
  //  |  gte1 gte2  |
  //  |   |    |    |
  //  |   Fusion3   |
  //  |    |   |    |
  // gte3 gte4 gte5 gte6
  //   \   |    |   /
  //    =====ROOT=====
  //
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule MultiOutputFusionTest

 %add_f16 {
   %x = f16[] parameter(0)
   %y = f16[] parameter(1)
   ROOT %add = f16[] add(%x, %y)
 }

 fused_computation.1 {
   arg.1 = f16[1024]{0} parameter(0)
   constant0 = f16[] constant(0)
   reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
   add.0 = f16[1024] add(arg.1, arg.1)
   ROOT tuple.1 = (f16[], f16[1024]) tuple(reduce.1, add.0)
 }

 fused_computation.2 {
   arg.1 = f16[1024]{0} parameter(0)
   constant0 = f16[] constant(0)
   reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
   add.0 = f16[1024] add(arg.1, arg.1)
   ROOT tuple.1 = (f16[], f16[1024]) tuple(reduce.1, add.0)
 }

 fused_computation.3 {
   arg.0 = f16[1024]{0} parameter(0)
   arg.1 = f16[1024]{0} parameter(1)
   add.0 = f16[1024] add(arg.0, arg.1)
   mul.0 = f16[1024] multiply(arg.0, arg.1)
   ROOT tuple.1 = (f16[1024], f16[1024]) tuple(add.0, mul.0)
 }

 ENTRY entry_computation {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   fusion.1 = (f16[],f16[1024]) fusion(arg.1), kind=kInput, calls=fused_computation.1
   fusion.2 = (f16[],f16[1024]) fusion(arg.2), kind=kInput, calls=fused_computation.2
   gte.3 = f16[] get-tuple-element(fusion.1), index=0
   gte.1 = f16[1024]{0} get-tuple-element(fusion.1), index=1
   gte.2 = f16[1024]{0} get-tuple-element(fusion.2), index=1
   gte.6 = f16[] get-tuple-element(fusion.2), index=0
   fusion.3 = (f16[1024],f16[1024]) fusion(gte.1, gte.2),
       kind=kLoop, calls=fused_computation.3
   gte.4 = f16[1024] get-tuple-element(fusion.3), index=0
   gte.5 = f16[1024]{0} get-tuple-element(fusion.3), index=1
   ROOT tuple.1 = (f16[], f16[1024], f16[1024]{0}, f16[])
       tuple(gte.3, gte.4, gte.5, gte.6)
 }
)")
                    .ValueOrDie();

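  // The pass should report a change: the sibling input fusions are merged
  // even though some of their get-tuple-element users feed another fusion.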
  EXPECT_TRUE(GpuHorizontalInputFusion().Run(module.get()).ValueOrDie());
}

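// Tests that the pass also picks up unfused (non-fusion) reduce instructions
// as horizontal fusion candidates.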
TEST_F(HorizontalInputFusionTest, NonfusionInstrs) {
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule NonfusionInstrs

 %add_f16 {
   %x = f16[] parameter(0)
   %y = f16[] parameter(1)
   ROOT %add = f16[] add(%x, %y)
 }

 ENTRY entry_computation {
   arg.0 = f16[1024]{0} parameter(0)
   arg.1 = f16[1024]{0} parameter(1)
   constant0 = f16[] constant(0)
   reduce.0 = f16[] reduce(arg.0, constant0), dimensions={0}, to_apply=%add_f16
   reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
   ROOT tuple.0 = (f16[], f16[]) tuple(reduce.0, reduce.1)
 }
)")
                    .ValueOrDie();

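  // Both standalone reduces should end up in a single multi-output fusion.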
  EXPECT_TRUE(GpuHorizontalInputFusion().Run(module.get()).ValueOrDie());

  const HloInstruction* entry_root =
      module->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, op::Tuple((op::GetTupleElement(op::Fusion())),
                                    (op::GetTupleElement(op::Fusion()))));

  const HloInstruction* fusion = entry_root->operand(0)->operand(0);
  ASSERT_TRUE(fusion->IsMultiOutputFusion());
  EXPECT_THAT(fusion->fused_expression_root(),
              op::Tuple(op::Reduce(), op::Reduce()));
}

}  // namespace
}  // namespace gpu
}  // namespace xla