1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_cse.h"
17
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22
23 #include "absl/memory/memory.h"
24 #include "tensorflow/compiler/xla/layout_util.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
29 #include "tensorflow/compiler/xla/service/hlo_module.h"
30 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
33 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
34 #include "tensorflow/compiler/xla/tests/test_utils.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/compiler/xla/xla_data.pb.h"
37
38 #include "tensorflow/compiler/xla/service/hlo_parser.h"
39 #include "tensorflow/compiler/xla/types.h"
40 #include "tensorflow/core/platform/types.h"
41
42 namespace op = xla::testing::opcode_matchers;
43
44 namespace xla {
45 namespace {
46
47 class HloCseTest : public HloTestBase {
48 protected:
HloCseTest()49 HloCseTest() {}
50 };
51
// Two scalar f32 constants holding the same value (42.0f) feed a single add.
// CSE is expected to fold them into one constant (3 -> 2 instructions), and
// the rewritten module must still evaluate to 42 + 42 = 84.
TEST_F(HloCseTest, CombineTwoConstants) {
  HloComputation::Builder b(TestName());
  HloInstruction* const lhs = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* const rhs = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  b.AddInstruction(
      HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));

  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(b.Build());
  EXPECT_EQ(3, entry->instruction_count());

  HloCSE pass(/*is_layout_sensitive=*/false);
  EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());
  EXPECT_EQ(2, entry->instruction_count());

  // The surviving constant is the first instruction of the computation.
  const HloInstruction* const survivor = *entry->instructions().begin();
  EXPECT_EQ(42.0f, survivor->literal().Get<float>({}));

  const auto actual = ExecuteAndTransfer(module->Clone(), {});
  const auto expected = LiteralUtil::CreateR0<float>(84.0f);
  EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, ErrorSpec(1e-4)));
}
78
// Two 2x2 constants with equal values but different layouts ({0,1} vs
// {1,0}). A layout-insensitive CSE must treat them as equal and merge them,
// and the add must keep producing the element-wise doubled matrix.
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
  HloComputation::Builder b(TestName());
  HloInstruction* const const_01 = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  HloInstruction* const const_10 = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
  HloInstruction* const sum = b.AddInstruction(HloInstruction::CreateBinary(
      const_01->shape(), HloOpcode::kAdd, const_01, const_10));

  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(b.Build());
  EXPECT_EQ(3, entry->instruction_count());
  EXPECT_THAT(sum, op::Add(const_01, const_10));

  HloCSE pass(/*is_layout_sensitive=*/false);
  EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());
  EXPECT_EQ(2, entry->instruction_count());

  // Either constant may be the one kept; both add operands must be it.
  const HloInstruction* const kept = sum->operand(0);
  EXPECT_THAT(kept, ::testing::AnyOf(const_01, const_10));
  EXPECT_THAT(sum, op::Add(kept, kept));

  const auto actual = ExecuteAndTransfer(module->Clone(), {});
  const auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
  EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, ErrorSpec(1e-4)));
}
110
// Same setup as the layout-insensitive variant, but the pass is layout
// sensitive. The two constants differ in layout, so nothing may be merged
// and the module must be left exactly as built.
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
  HloComputation::Builder b(TestName());
  HloInstruction* const const_01 = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  HloInstruction* const const_10 = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
  HloInstruction* const sum = b.AddInstruction(HloInstruction::CreateBinary(
      const_01->shape(), HloOpcode::kAdd, const_01, const_10));

  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(b.Build());
  EXPECT_EQ(3, entry->instruction_count());
  EXPECT_THAT(sum, op::Add(const_01, const_10));

  HloCSE pass(/*is_layout_sensitive=*/true);
  EXPECT_FALSE(pass.Run(module.get()).ValueOrDie());

  // Unchanged: both constants and the add survive.
  EXPECT_EQ(3, entry->instruction_count());
  EXPECT_THAT(sum, op::Add(const_01, const_10));

  const auto actual = ExecuteAndTransfer(module->Clone(), {});
  const auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
  EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, ErrorSpec(1e-4)));
}
140
TEST_F(HloCseTest, ConstantsSameValueDifferentType) {
  // Test that constants holding the same value (42) but with different
  // element types (u32, s32, u64, s64, f64, f32) are *not* commoned.
  auto builder = HloComputation::Builder(TestName());
  std::vector<HloInstruction*> constants;
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(42))));
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(42))));
  // Integer literals for the integer-typed constants (no implicit
  // floating-point -> integer conversion).
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint64>(42))));
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64>(42))));
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<double>(42.0))));
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))));
  // A second, identical f32 constant: the one pair CSE is allowed to merge,
  // which verifies that the pass actually does something.
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))));

  // Convert every constant to f32 so they can all be summed.
  const Shape shape_r0 = ShapeUtil::MakeShape(F32, {});
  for (HloInstruction*& constant : constants) {
    constant = builder.AddInstruction(
        HloInstruction::CreateConvert(shape_r0, constant));
  }
  // Left-leaning add chain over all converted constants.
  HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary(
      shape_r0, HloOpcode::kAdd, constants[0], constants[1]));
  for (size_t i = 2; i < constants.size(); ++i) {
    root = builder.AddInstruction(HloInstruction::CreateBinary(
        shape_r0, HloOpcode::kAdd, root, constants[i]));
  }

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  // 7 constants + 7 converts + 6 adds.
  EXPECT_EQ(20, computation->instruction_count());

  HloCSE cse(/*is_layout_sensitive=*/false);
  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());

  // CSE will remove both the second float(42.0f) and the corresponding
  // convert/cast.
  EXPECT_EQ(18, computation->instruction_count());
}
186
// Identical non-scalar (2x2) constants are merged, while a constant of the
// same shape but different contents is left alone. A tuple ties all three
// constants together so they can be inspected through its operands.
TEST_F(HloCseTest, NonscalarConstants) {
  HloComputation::Builder b(TestName());
  HloInstruction* const same_a =
      b.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  HloInstruction* const same_b =
      b.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  HloInstruction* const different =
      b.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}})));
  HloInstruction* const tuple = b.AddInstruction(
      HloInstruction::CreateTuple({same_a, same_b, different}));

  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(b.Build());
  EXPECT_EQ(4, entry->instruction_count());
  EXPECT_THAT(tuple, op::Tuple(same_a, same_b, different));

  HloCSE pass(/*is_layout_sensitive=*/false);
  EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());
  EXPECT_EQ(3, entry->instruction_count());

  // The first two tuple slots now share whichever equal constant was kept.
  const HloInstruction* const merged = tuple->operand(0);
  EXPECT_THAT(merged, ::testing::AnyOf(same_a, same_b));
  EXPECT_THAT(tuple, op::Tuple(merged, merged, different));
}
221
// Three identical exp(constant) instructions collapse into a single one
// (5 -> 3 instructions), and every tuple slot then refers to that survivor.
TEST_F(HloCseTest, IdenticalInstructions) {
  HloComputation::Builder b(TestName());
  HloInstruction* const constant = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));

  std::vector<HloInstruction*> exps;
  for (int copy = 0; copy < 3; ++copy) {
    exps.push_back(b.AddInstruction(HloInstruction::CreateUnary(
        constant->shape(), HloOpcode::kExp, constant)));
  }
  HloInstruction* const tuple =
      b.AddInstruction(HloInstruction::CreateTuple(exps));

  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(b.Build());
  EXPECT_EQ(5, entry->instruction_count());
  EXPECT_THAT(tuple, op::Tuple(exps[0], exps[1], exps[2]));

  HloCSE pass(/*is_layout_sensitive=*/true);
  EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());
  EXPECT_EQ(3, entry->instruction_count());

  const HloInstruction* const survivor = tuple->operand(0);
  EXPECT_THAT(survivor, ::testing::AnyOf(exps[0], exps[1], exps[2]));
  EXPECT_THAT(tuple, op::Tuple(survivor, survivor, survivor));
}
250
251 // Test two identical while loops with same inputs
TEST_F(HloCseTest,WhileLoopsIdenticalConditionsAndBodiesSameInput)252 TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesSameInput) {
253 const char* const hlo_string = R"(
254 HloModule WhileLoopsIdenticalConditionsAndBodiesSameInput
255
256 %body (param: (f32[], f32[])) -> (f32[], f32[]) {
257 %param = (f32[], f32[]) parameter(0)
258 %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param),
259 index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param),
260 index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1)
261 ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add)
262 }
263
264 %condition (param.1: (f32[], f32[])) -> pred[] {
265 %param.1 = (f32[], f32[]) parameter(0)
266 ROOT %constant = pred[] constant(false)
267 }
268
269 %condition.1 (param.2: (f32[], f32[])) -> pred[] {
270 %param.2 = (f32[], f32[]) parameter(0)
271 ROOT %constant.1 = pred[] constant(false)
272 }
273
274 ENTRY %WhileLoopsIdenticalConditionsAndBodiesSameInput () -> (f32[], f32[])
275 { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2) %tuple.1 =
276 (f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3) %while = (f32[],
277 f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT
278 %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1),
279 condition=%condition.1, body=%body
280 })";
281
282 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
283 auto computation = m->entry_computation();
284
285 EXPECT_EQ(5, computation->instruction_count());
286 HloCSE cse(true);
287 EXPECT_TRUE(cse.Run(m.get()).ValueOrDie());
288 EXPECT_EQ(4, computation->instruction_count());
289 }
290
291 // Test two while loops with same conditions, same inputs, but different
292 // bodies
TEST_F(HloCseTest,WhileLoopsIdenticalConditionsSameInputAndDifferentBodies)293 TEST_F(HloCseTest, WhileLoopsIdenticalConditionsSameInputAndDifferentBodies) {
294 const char* const hlo_string = R"(
295 HloModule WhileLoopsIdenticalConditionsSameInputAndDifferentBodies
296
297 %body (param: (f32[], f32[])) -> (f32[], f32[]) {
298 %param = (f32[], f32[]) parameter(0)
299 %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param),
300 index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param),
301 index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1)
302 ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add)
303 }
304
305 %body2 (param.1: (f32[], f32[])) -> (f32[], f32[]) {
306 %param.1 = (f32[], f32[]) parameter(0)
307 %get-tuple-element.2 = f32[] get-tuple-element((f32[], f32[]) %param.1),
308 index=0 %get-tuple-element.3 = f32[] get-tuple-element((f32[], f32[]) %param.1),
309 index=1 %sub = f32[] subtract(f32[] %get-tuple-element.2, f32[]
310 %get-tuple-element.3) ROOT %tuple.2 = (f32[], f32[]) tuple(f32[]
311 %get-tuple-element.2, f32[] %sub)
312 }
313
314 %condition (param.2: (f32[], f32[])) -> pred[] {
315 %param.2 = (f32[], f32[]) parameter(0)
316 ROOT %constant = pred[] constant(false)
317 }
318
319 %condition.1 (param.3: (f32[], f32[])) -> pred[] {
320 %param.3 = (f32[], f32[]) parameter(0)
321 ROOT %constant.1 = pred[] constant(false)
322 }
323
324 ENTRY %WhileLoopsIdenticalConditionsSameInputAndDifferentBodies () ->
325 (f32[], f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2)
326 %tuple.1 = (f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3)
327 %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1),
328 condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[],
329 f32[]) %tuple.1), condition=%condition.1, body=%body2
330 })";
331
332 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
333 auto computation = m->entry_computation();
334
335 EXPECT_EQ(5, computation->instruction_count());
336 HloCSE cse(true);
337 EXPECT_FALSE(cse.Run(m.get()).ValueOrDie());
338 EXPECT_EQ(5, computation->instruction_count());
339 }
340
341 // Test two identical while loops with different inputs
TEST_F(HloCseTest,WhileLoopsIdenticalConditionsAndBodiesDifferentInput)342 TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesDifferentInput) {
343 const char* const hlo_string = R"(
344 HloModule WhileLoopsIdenticalConditionsAndBodiesDifferentInput
345
346 %body (param: (f32[], f32[])) -> (f32[], f32[]) {
347 %param = (f32[], f32[]) parameter(0)
348 %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param),
349 index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param),
350 index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1)
351 ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add)
352 }
353
354 %condition (param.1: (f32[], f32[])) -> pred[] {
355 %param.1 = (f32[], f32[]) parameter(0)
356 ROOT %constant = pred[] constant(false)
357 }
358
359 %condition.1 (param.2: (f32[], f32[])) -> pred[] {
360 %param.2 = (f32[], f32[]) parameter(0)
361 ROOT %constant.1 = pred[] constant(false)
362 }
363
364 ENTRY %WhileLoopsIdenticalConditionsAndBodiesDifferentInput () -> (f32[],
365 f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2)
366 %tuple.1 = (f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3)
367 %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1),
368 condition=%condition, body=%body %constant.4 = f32[] constant(1) %constant.5 =
369 f32[] constant(2) %tuple.2 = (f32[], f32[]) tuple(f32[] %constant.4, f32[]
370 %constant.5) ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.2),
371 condition=%condition.1, body=%body
372 })";
373
374 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
375 auto computation = m->entry_computation();
376
377 EXPECT_EQ(8, computation->instruction_count());
378 HloCSE cse(true);
379 EXPECT_FALSE(cse.Run(m.get()).ValueOrDie());
380 EXPECT_EQ(8, computation->instruction_count());
381 }
382
383 // Test two while loops with identical bodies and same inputs, but different
384 // conditions
TEST_F(HloCseTest,WhileLoopsIdenticalBodiesAndInputDifferntConditions)385 TEST_F(HloCseTest, WhileLoopsIdenticalBodiesAndInputDifferntConditions) {
386 const char* const hlo_string = R"(
387 HloModule WhileLoopsIdenticalBodiesAndInputDifferntConditions
388
389 %body (param: (f32[], f32[])) -> (f32[], f32[]) {
390 %param = (f32[], f32[]) parameter(0)
391 %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param),
392 index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param),
393 index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1)
394 ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add)
395 }
396
397 %condition (param.1: (f32[], f32[])) -> pred[] {
398 %param.1 = (f32[], f32[]) parameter(0)
399 ROOT %constant = pred[] constant(false)
400 }
401
402 %condition.1 (param.2: (f32[], f32[])) -> pred[] {
403 %param.2 = (f32[], f32[]) parameter(0)
404 ROOT %constant.1 = pred[] constant(true)
405 }
406
407 ENTRY %WhileLoopsIdenticalBodiesAndInputDifferntConditions () -> (f32[],
408 f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2)
409 %tuple.1 = (f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3)
410 %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1),
411 condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[],
412 f32[]) %tuple.1), condition=%condition.1, body=%body
413 })";
414
415 TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
416 auto computation = m->entry_computation();
417
418 EXPECT_EQ(5, computation->instruction_count());
419 HloCSE cse(true);
420 EXPECT_FALSE(cse.Run(m.get()).ValueOrDie());
421 EXPECT_EQ(5, computation->instruction_count());
422 }
423
// Two exp(constant) instructions that differ only in their result layout.
// A layout-sensitive CSE must keep both of them.
TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) {
  HloComputation::Builder b(TestName());
  HloInstruction* const constant =
      b.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));

  // Adds exp(constant) to the builder and overrides its result layout.
  const auto add_exp_with_layout = [&](const Layout& layout) {
    HloInstruction* const exp = b.AddInstruction(HloInstruction::CreateUnary(
        constant->shape(), HloOpcode::kExp, constant));
    *exp->mutable_shape()->mutable_layout() = layout;
    return exp;
  };
  HloInstruction* const exp_01 =
      add_exp_with_layout(LayoutUtil::MakeLayout({0, 1}));
  HloInstruction* const exp_10 =
      add_exp_with_layout(LayoutUtil::MakeLayout({1, 0}));
  HloInstruction* const tuple =
      b.AddInstruction(HloInstruction::CreateTuple({exp_01, exp_10}));

  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(b.Build());
  EXPECT_EQ(4, entry->instruction_count());
  EXPECT_THAT(tuple, op::Tuple(exp_01, exp_10));

  HloCSE pass(/*is_layout_sensitive=*/true);
  EXPECT_FALSE(pass.Run(module.get()).ValueOrDie());

  EXPECT_EQ(4, entry->instruction_count());
  EXPECT_THAT(tuple, op::Tuple(exp_01, exp_10));
}
454
// Two exp(constant) instructions that differ only in their result layout.
// A layout-insensitive CSE treats them as the same and merges them.
TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) {
  HloComputation::Builder b(TestName());
  HloInstruction* const constant =
      b.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));

  // Adds exp(constant) to the builder and overrides its result layout.
  const auto add_exp_with_layout = [&](const Layout& layout) {
    HloInstruction* const exp = b.AddInstruction(HloInstruction::CreateUnary(
        constant->shape(), HloOpcode::kExp, constant));
    *exp->mutable_shape()->mutable_layout() = layout;
    return exp;
  };
  HloInstruction* const exp_01 =
      add_exp_with_layout(LayoutUtil::MakeLayout({0, 1}));
  HloInstruction* const exp_10 =
      add_exp_with_layout(LayoutUtil::MakeLayout({1, 0}));
  HloInstruction* const tuple =
      b.AddInstruction(HloInstruction::CreateTuple({exp_01, exp_10}));

  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(b.Build());
  EXPECT_EQ(4, entry->instruction_count());
  EXPECT_THAT(tuple, op::Tuple(exp_01, exp_10));

  HloCSE pass(/*is_layout_sensitive=*/false);
  EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());
  EXPECT_EQ(3, entry->instruction_count());

  const HloInstruction* const kept = tuple->operand(0);
  EXPECT_THAT(kept, ::testing::AnyOf(exp_01, exp_10));
  EXPECT_THAT(tuple, op::Tuple(kept, kept));
}
487
// CSE also runs inside fused computations. Two identical add(p0, p1)
// instructions are fused together with the multiply that consumes them;
// after the pass the fused multiply reads a single add twice.
TEST_F(HloCseTest, FusionInternalCSE) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  HloInstruction* const p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, scalar, "p0"));
  HloInstruction* const p1 =
      b.AddInstruction(HloInstruction::CreateParameter(1, scalar, "p1"));
  HloInstruction* const sum_a = b.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, p0, p1));
  HloInstruction* const sum_b = b.AddInstruction(
      HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, p0, p1));
  HloInstruction* const product = b.AddInstruction(HloInstruction::CreateBinary(
      scalar, HloOpcode::kMultiply, sum_a, sum_b));

  HloComputation* const entry = module->AddEntryComputation(b.Build());
  HloInstruction* const fusion = entry->CreateFusionInstruction(
      {product, sum_a, sum_b}, HloInstruction::FusionKind::kLoop);
  HloComputation* const fused = fusion->fused_instructions_computation();

  EXPECT_EQ(5, fused->instruction_count());
  HloCSE pass(/*is_layout_sensitive=*/false);
  EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());
  EXPECT_EQ(4, fused->instruction_count());

  const HloInstruction* const fused_root = fused->root_instruction();
  EXPECT_THAT(fused_root,
              op::Multiply(fused_root->operand(0), fused_root->operand(0)));
}
521
// Two identical expression trees over one constant:
//
//   constant = 42.0
//   negate1 = neg(constant)   exp1 = exp(constant)   add1 = add(negate1, exp1)
//   negate2 = neg(constant)   exp2 = exp(constant)   add2 = add(negate2, exp2)
//   tuple   = tuple(add1, add2)
//
// CSE must merge the second tree into the first (8 -> 5 instructions).
TEST_F(HloCseTest, IdenticalExpressions) {
  HloComputation::Builder b(TestName());
  HloInstruction* const constant = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));

  // Handles to one emitted negate/exp/add triple.
  struct Expr {
    HloInstruction* negate;
    HloInstruction* exp;
    HloInstruction* add;
  };
  // Emits negate(constant), exp(constant), add(negate, exp) in that order.
  const auto emit_expr = [&b, constant]() {
    Expr e;
    e.negate = b.AddInstruction(HloInstruction::CreateUnary(
        constant->shape(), HloOpcode::kNegate, constant));
    e.exp = b.AddInstruction(HloInstruction::CreateUnary(
        constant->shape(), HloOpcode::kExp, constant));
    e.add = b.AddInstruction(HloInstruction::CreateBinary(
        constant->shape(), HloOpcode::kAdd, e.negate, e.exp));
    return e;
  };
  const Expr first = emit_expr();
  const Expr second = emit_expr();

  HloInstruction* const tuple =
      b.AddInstruction(HloInstruction::CreateTuple({first.add, second.add}));

  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(b.Build());
  EXPECT_EQ(8, entry->instruction_count());
  EXPECT_THAT(tuple, op::Tuple(op::Add(first.negate, first.exp),
                               op::Add(second.negate, second.exp)));

  HloCSE pass(/*is_layout_sensitive=*/false);
  EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());
  EXPECT_EQ(5, entry->instruction_count());

  const HloInstruction* const shared = tuple->operand(0);
  EXPECT_THAT(tuple, op::Tuple(shared, shared));
  EXPECT_THAT(shared, op::Add(op::Negate(), op::Exp()));
}
571
TEST_F(HloCseTest, DoNotCombineRng) {
  // Test that two RNG ops are not commoned, even though they have identical
  // shapes, distributions and operands.
  auto builder = HloComputation::Builder(TestName());
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  auto rng1 = builder.AddInstruction(HloInstruction::CreateRng(
      ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM,
      {constant1, constant2}));
  auto rng2 = builder.AddInstruction(HloInstruction::CreateRng(
      ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM,
      {constant1, constant2}));

  builder.AddInstruction(HloInstruction::CreateBinary(
      constant1->shape(), HloOpcode::kAdd, rng1, rng2));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root, op::Add(rng1, rng2));

  // Keep whatever integer type instruction_count() returns instead of
  // narrowing it into a uint32.
  const auto count_before = computation->instruction_count();

  HloCSE cse(/*is_layout_sensitive=*/false);
  EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());

  const auto count_after = computation->instruction_count();
  EXPECT_EQ(count_before, count_after);
  root = computation->root_instruction();
  EXPECT_THAT(root, op::Add(rng1, rng2));
}
605
// Two maps that call the same impure function on the same input must not be
// commoned. The function is impure because it contains an rng op.
TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) {
  auto module = CreateNewVerifiedModule();

  // rng_function(param) = rng(0, 1) + param; the rng makes it impure.
  HloComputation* const rng_function = [&]() -> HloComputation* {
    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
    HloComputation::Builder b(TestName() + "_rng_fun");
    HloInstruction* const lo = b.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
    HloInstruction* const hi = b.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
    HloInstruction* const rng = b.AddInstruction(HloInstruction::CreateRng(
        scalar_shape, RandomDistribution::RNG_UNIFORM, {lo, hi}));
    HloInstruction* const param = b.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "param"));
    b.AddInstruction(HloInstruction::CreateBinary(
        scalar_shape, HloOpcode::kAdd, rng, param));
    return module->AddEmbeddedComputation(b.Build());
  }();

  // The entry computation maps rng_function twice over the same constant and
  // adds the two results.
  HloComputation* const computation = [&]() -> HloComputation* {
    HloComputation::Builder b(TestName());
    HloInstruction* const input = b.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({5.0f})));
    HloInstruction* const map_a = b.AddInstruction(
        HloInstruction::CreateMap(input->shape(), {input}, rng_function));
    HloInstruction* const map_b = b.AddInstruction(
        HloInstruction::CreateMap(input->shape(), {input}, rng_function));
    b.AddInstruction(HloInstruction::CreateBinary(
        input->shape(), HloOpcode::kAdd, map_a, map_b));
    return module->AddEntryComputation(b.Build());
  }();

  EXPECT_EQ(4, computation->instruction_count());
  EXPECT_THAT(computation->root_instruction(), op::Add(op::Map(), op::Map()));

  VLOG(3) << "before: " << module->ToString();

  HloCSE pass(/*is_layout_sensitive=*/false);
  EXPECT_FALSE(pass.Run(module.get()).ValueOrDie());

  VLOG(3) << "after: " << module->ToString();

  EXPECT_EQ(4, computation->instruction_count());
  EXPECT_THAT(computation->root_instruction(),
              op::Add(op::Map(op::Constant()), op::Map(op::Constant())));
}
660
// Two reduces over the same operands whose to_apply computations are
// distinct objects but structurally identical. CSE compares the computations
// by structure, so both tuple elements end up as the same reduce.
TEST_F(HloCseTest, CompareComputations) {
  constexpr char kHloText[] = R"(
HloModule m

add_computation {
add_lhs = f32[] parameter(0)
add_rhs = f32[] parameter(1)
ROOT add_root = f32[] add(add_lhs, add_rhs)
}

add_computation2 {
add_lhs2 = f32[] parameter(0)
add_rhs2 = f32[] parameter(1)
ROOT add_root2 = f32[] add(add_lhs2, add_rhs2)
}

ENTRY entry {
p = f32[10]{0} parameter(0)
c = f32[] constant(0)
r1 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation
r2 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation2
ROOT f2 = (f32[],f32[]) tuple(r1, r2)
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  HloCSE pass(/*is_layout_sensitive=*/false);
  EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());

  const HloInstruction* const tuple =
      module->entry_computation()->root_instruction();
  EXPECT_EQ(tuple->operand(0), tuple->operand(1));
}
691
// Two equal u32 constants that are not connected to each other (no users)
// live in disjoint domains and must not be collapsed.
TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) {
  HloComputation::Builder b(TestName());
  for (int copy = 0; copy < 2; ++copy) {
    b.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(42)));
  }

  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(b.Build());
  EXPECT_EQ(2, entry->instruction_count());

  HloCSE pass(/*is_layout_sensitive=*/false);
  EXPECT_FALSE(pass.Run(module.get()).ValueOrDie());
  EXPECT_EQ(2, entry->instruction_count());
}
711
// Three negate(domain(param)) chains. Two of them cross identical sharding
// domains (device 0 -> 1 -> 0) and may be merged; the third goes through
// device 2 and must stay separate.
TEST_F(HloCseTest, Domain) {
  constexpr char kHloText[] = R"(
HloModule module
ENTRY %entry {
%param = f32[] parameter(0), sharding={maximal device=0}
%domain.0 = f32[] domain(%param),
domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
%domain.1 = f32[] domain(%param),
domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
%domain.2 = f32[] domain(%param),
domain={kind="sharding", entry={maximal device=0}, exit={maximal device=2}}
%negate.0 = f32[] negate(%domain.0)
%negate.1 = f32[] negate(%domain.1)
%negate.2 = f32[] negate(%domain.2)
%domain.3 = f32[] domain(%negate.0),
domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}}
%domain.4 = f32[] domain(%negate.1),
domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}}
%domain.5 = f32[] domain(%negate.2),
domain={kind="sharding", entry={maximal device=2}, exit={maximal device=0}}
%add = f32[] add(%domain.3, %domain.4)
ROOT %sub = f32[] subtract(%add, %domain.5)
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  HloCSE pass(/*is_layout_sensitive=*/false);
  EXPECT_TRUE(pass.Run(module.get()).ValueOrDie());

  const HloInstruction* const sub =
      module->entry_computation()->root_instruction();
  const HloInstruction* const add = sub->operand(0);
  const HloInstruction* const add_lhs = add->operand(0);
  const HloInstruction* const add_rhs = add->operand(1);
  const HloInstruction* const device2_path = sub->operand(1);
  EXPECT_EQ(add_lhs, add_rhs);
  EXPECT_NE(add_lhs, device2_path);
  EXPECT_NE(add_rhs, device2_path);
}
745
746 } // namespace
747 } // namespace xla
748