/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_cse.h"
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/memory/memory.h"
24 #include "tensorflow/compiler/xla/layout_util.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
29 #include "tensorflow/compiler/xla/service/hlo_module.h"
30 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
33 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
34 #include "tensorflow/compiler/xla/tests/test_utils.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/compiler/xla/xla_data.pb.h"
37 
38 #include "tensorflow/compiler/xla/service/hlo_parser.h"
39 #include "tensorflow/compiler/xla/types.h"
40 #include "tensorflow/core/platform/types.h"
41 
42 namespace op = xla::testing::opcode_matchers;
43 
44 namespace xla {
45 namespace {
46 
47 class HloCseTest : public HloTestBase {
48  protected:
HloCseTest()49   HloCseTest() {}
50 };
51 
TEST_F(HloCseTest, CombineTwoConstants) {
  // Two scalar constants holding the same value must collapse into one.
  HloComputation::Builder b(TestName());
  HloInstruction* lhs = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* rhs = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  b.AddInstruction(
      HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Two constants plus the add.
  EXPECT_EQ(3, entry->instruction_count());

  HloCSE cse(/*is_layout_sensitive=*/false);
  const bool changed = cse.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  // One constant is gone; the first remaining instruction is the survivor.
  EXPECT_EQ(2, entry->instruction_count());
  HloInstruction* survivor = *entry->instructions().begin();
  EXPECT_EQ(42.0f, survivor->literal().Get<float>({}));

  // The rewritten module must still evaluate to 42 + 42.
  auto result = ExecuteAndTransfer(module->Clone(), {});
  auto expected = LiteralUtil::CreateR0<float>(84.0);
  EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
}
78 
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
  // With layout sensitivity switched off, constants that hold the same values
  // but were created with different layouts are expected to be merged.
  HloComputation::Builder b(TestName());
  HloInstruction* layout01_const = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  HloInstruction* layout10_const = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
  HloInstruction* sum = b.AddInstruction(
      HloInstruction::CreateBinary(layout01_const->shape(), HloOpcode::kAdd,
                                   layout01_const, layout10_const));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_EQ(3, entry->instruction_count());
  EXPECT_THAT(sum, op::Add(layout01_const, layout10_const));

  HloCSE cse(/*is_layout_sensitive=*/false);
  const bool changed = cse.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  // Either constant may be the one that survives; both add operands must now
  // refer to it.
  EXPECT_EQ(2, entry->instruction_count());
  const HloInstruction* survivor = sum->operand(0);
  EXPECT_THAT(survivor, ::testing::AnyOf(layout01_const, layout10_const));
  EXPECT_THAT(sum, op::Add(survivor, survivor));

  auto result = ExecuteAndTransfer(module->Clone(), {});
  auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
  EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
}
110 
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
  // With layout sensitivity switched on, constants that hold the same values
  // but were created with different layouts must *not* be merged.
  HloComputation::Builder b(TestName());
  HloInstruction* layout01_const = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  HloInstruction* layout10_const = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
  HloInstruction* sum = b.AddInstruction(
      HloInstruction::CreateBinary(layout01_const->shape(), HloOpcode::kAdd,
                                   layout01_const, layout10_const));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_EQ(3, entry->instruction_count());
  EXPECT_THAT(sum, op::Add(layout01_const, layout10_const));

  HloCSE cse(/*is_layout_sensitive=*/true);
  const bool changed = cse.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);

  // Nothing was removed and the add still reads both original constants.
  EXPECT_EQ(3, entry->instruction_count());
  EXPECT_THAT(sum, op::Add(layout01_const, layout10_const));

  auto result = ExecuteAndTransfer(module->Clone(), {});
  auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
  EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
}
140 
TEST_F(HloCseTest, ConstantsSameValueDifferentType) {
  // Test that constants with the same value but different type are *not*
  // commoned.
  auto builder = HloComputation::Builder(TestName());
  std::vector<HloInstruction*> constants;
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(42))));
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(42))));
  // Integer literals for the integer-typed constants: no double-to-integer
  // conversion at the call site.
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint64>(42))));
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64>(42))));
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<double>(42.0))));
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))));
  // Duplicate the float constant to verify something happens.
  constants.push_back(builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))));

  // Replace every entry with a convert of that constant to scalar F32 so that
  // all of them can be summed together.
  const Shape shape_r0 = ShapeUtil::MakeShape(F32, {});
  for (HloInstruction*& constant : constants) {
    constant = builder.AddInstruction(
        HloInstruction::CreateConvert(shape_r0, constant));
  }

  // Chain the converts into a single sum so each one has a user.
  HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary(
      shape_r0, HloOpcode::kAdd, constants[0], constants[1]));
  for (size_t i = 2; i < constants.size(); ++i) {
    root = builder.AddInstruction(HloInstruction::CreateBinary(
        shape_r0, HloOpcode::kAdd, root, constants[i]));
  }

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  // 7 constants + 7 converts + 6 adds.
  EXPECT_EQ(20, computation->instruction_count());

  HloCSE cse(/*is_layout_sensitive=*/false);
  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());

  // CSE will remove both the second float(42.0f) and the corresponding
  // convert/cast.
  EXPECT_EQ(18, computation->instruction_count());
}
186 
TEST_F(HloCseTest, NonscalarConstants) {
  // Equal rank-2 constants are merged; a constant of the same shape with
  // different contents is left alone.
  HloComputation::Builder b(TestName());
  HloInstruction* dup_a = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  HloInstruction* dup_b = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  HloInstruction* distinct = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}})));

  // A tuple over all three gives a single place from which to inspect which
  // constants are still in use afterwards.
  HloInstruction* bundle =
      b.AddInstruction(HloInstruction::CreateTuple({dup_a, dup_b, distinct}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_EQ(4, entry->instruction_count());
  EXPECT_THAT(bundle, op::Tuple(dup_a, dup_b, distinct));

  HloCSE cse(/*is_layout_sensitive=*/false);
  const bool changed = cse.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  // Exactly one of the duplicates remains and fills the first two slots.
  EXPECT_EQ(3, entry->instruction_count());
  const HloInstruction* survivor = bundle->operand(0);
  EXPECT_THAT(survivor, ::testing::AnyOf(dup_a, dup_b));
  EXPECT_THAT(bundle, op::Tuple(survivor, survivor, distinct));
}
221 
TEST_F(HloCseTest, IdenticalInstructions) {
  // Three exp instructions over the same operand must collapse into one.
  HloComputation::Builder b(TestName());
  HloInstruction* input = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  HloInstruction* exp_a = b.AddInstruction(
      HloInstruction::CreateUnary(input->shape(), HloOpcode::kExp, input));
  HloInstruction* exp_b = b.AddInstruction(
      HloInstruction::CreateUnary(input->shape(), HloOpcode::kExp, input));
  HloInstruction* exp_c = b.AddInstruction(
      HloInstruction::CreateUnary(input->shape(), HloOpcode::kExp, input));
  HloInstruction* bundle =
      b.AddInstruction(HloInstruction::CreateTuple({exp_a, exp_b, exp_c}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_EQ(5, entry->instruction_count());
  EXPECT_THAT(bundle, op::Tuple(exp_a, exp_b, exp_c));

  HloCSE cse(/*is_layout_sensitive=*/true);
  const bool changed = cse.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  // Constant, one exp and the tuple remain; every tuple slot reads that exp.
  EXPECT_EQ(3, entry->instruction_count());
  const HloInstruction* survivor = bundle->operand(0);
  EXPECT_THAT(survivor, ::testing::AnyOf(exp_a, exp_b, exp_c));
  EXPECT_THAT(bundle, op::Tuple(survivor, survivor, survivor));
}
250 
251 // Test two identical while loops with same inputs
TEST_F(HloCseTest,WhileLoopsIdenticalConditionsAndBodiesSameInput)252 TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesSameInput) {
253   const char* const hlo_string = R"(
254     HloModule WhileLoopsIdenticalConditionsAndBodiesSameInput
255 
256     %body (param: (f32[], f32[])) -> (f32[], f32[]) {
257       %param = (f32[], f32[]) parameter(0)
258       %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param),
259 index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param),
260 index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1)
261       ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add)
262     }
263 
264     %condition (param.1: (f32[], f32[])) -> pred[] {
265       %param.1 = (f32[], f32[]) parameter(0)
266       ROOT %constant = pred[] constant(false)
267     }
268 
269     %condition.1 (param.2: (f32[], f32[])) -> pred[] {
270       %param.2 = (f32[], f32[]) parameter(0)
271       ROOT %constant.1 = pred[] constant(false)
272     }
273 
274     ENTRY %WhileLoopsIdenticalConditionsAndBodiesSameInput () -> (f32[], f32[])
275 { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2) %tuple.1 =
276 (f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3) %while = (f32[],
277 f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT
278 %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1),
279 condition=%condition.1, body=%body
280     })";
281 
282   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
283   auto computation = m->entry_computation();
284 
285   EXPECT_EQ(5, computation->instruction_count());
286   HloCSE cse(true);
287   EXPECT_TRUE(cse.Run(m.get()).ValueOrDie());
288   EXPECT_EQ(4, computation->instruction_count());
289 }
290 
291 // Test two while loops with same conditions, same inputs, but different
292 // bodies
TEST_F(HloCseTest,WhileLoopsIdenticalConditionsSameInputAndDifferentBodies)293 TEST_F(HloCseTest, WhileLoopsIdenticalConditionsSameInputAndDifferentBodies) {
294   const char* const hlo_string = R"(
295     HloModule WhileLoopsIdenticalConditionsSameInputAndDifferentBodies
296 
297     %body (param: (f32[], f32[])) -> (f32[], f32[]) {
298       %param = (f32[], f32[]) parameter(0)
299       %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param),
300 index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param),
301 index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1)
302       ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add)
303     }
304 
305     %body2 (param.1: (f32[], f32[])) -> (f32[], f32[]) {
306       %param.1 = (f32[], f32[]) parameter(0)
307       %get-tuple-element.2 = f32[] get-tuple-element((f32[], f32[]) %param.1),
308 index=0 %get-tuple-element.3 = f32[] get-tuple-element((f32[], f32[]) %param.1),
309 index=1 %sub = f32[] subtract(f32[] %get-tuple-element.2, f32[]
310 %get-tuple-element.3) ROOT %tuple.2 = (f32[], f32[]) tuple(f32[]
311 %get-tuple-element.2, f32[] %sub)
312     }
313 
314     %condition (param.2: (f32[], f32[])) -> pred[] {
315       %param.2 = (f32[], f32[]) parameter(0)
316       ROOT %constant = pred[] constant(false)
317     }
318 
319     %condition.1 (param.3: (f32[], f32[])) -> pred[] {
320       %param.3 = (f32[], f32[]) parameter(0)
321       ROOT %constant.1 = pred[] constant(false)
322     }
323 
324     ENTRY %WhileLoopsIdenticalConditionsSameInputAndDifferentBodies () ->
325 (f32[], f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2)
326       %tuple.1 = (f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3)
327       %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1),
328 condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[],
329 f32[]) %tuple.1), condition=%condition.1, body=%body2
330     })";
331 
332   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
333   auto computation = m->entry_computation();
334 
335   EXPECT_EQ(5, computation->instruction_count());
336   HloCSE cse(true);
337   EXPECT_FALSE(cse.Run(m.get()).ValueOrDie());
338   EXPECT_EQ(5, computation->instruction_count());
339 }
340 
341 // Test two identical while loops with different inputs
TEST_F(HloCseTest,WhileLoopsIdenticalConditionsAndBodiesDifferentInput)342 TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesDifferentInput) {
343   const char* const hlo_string = R"(
344     HloModule WhileLoopsIdenticalConditionsAndBodiesDifferentInput
345 
346     %body (param: (f32[], f32[])) -> (f32[], f32[]) {
347       %param = (f32[], f32[]) parameter(0)
348       %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param),
349 index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param),
350 index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1)
351       ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add)
352     }
353 
354     %condition (param.1: (f32[], f32[])) -> pred[] {
355       %param.1 = (f32[], f32[]) parameter(0)
356       ROOT %constant = pred[] constant(false)
357     }
358 
359     %condition.1 (param.2: (f32[], f32[])) -> pred[] {
360       %param.2 = (f32[], f32[]) parameter(0)
361       ROOT %constant.1 = pred[] constant(false)
362     }
363 
364     ENTRY %WhileLoopsIdenticalConditionsAndBodiesDifferentInput () -> (f32[],
365 f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2)
366       %tuple.1 = (f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3)
367       %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1),
368 condition=%condition, body=%body %constant.4 = f32[] constant(1) %constant.5 =
369 f32[] constant(2) %tuple.2 = (f32[], f32[]) tuple(f32[] %constant.4, f32[]
370 %constant.5) ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.2),
371 condition=%condition.1, body=%body
372     })";
373 
374   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
375   auto computation = m->entry_computation();
376 
377   EXPECT_EQ(8, computation->instruction_count());
378   HloCSE cse(true);
379   EXPECT_FALSE(cse.Run(m.get()).ValueOrDie());
380   EXPECT_EQ(8, computation->instruction_count());
381 }
382 
383 // Test two while loops with identical bodies and same inputs, but different
384 // conditions
TEST_F(HloCseTest,WhileLoopsIdenticalBodiesAndInputDifferntConditions)385 TEST_F(HloCseTest, WhileLoopsIdenticalBodiesAndInputDifferntConditions) {
386   const char* const hlo_string = R"(
387     HloModule WhileLoopsIdenticalBodiesAndInputDifferntConditions
388 
389     %body (param: (f32[], f32[])) -> (f32[], f32[]) {
390       %param = (f32[], f32[]) parameter(0)
391       %get-tuple-element = f32[] get-tuple-element((f32[], f32[]) %param),
392 index=0 %get-tuple-element.1 = f32[] get-tuple-element((f32[], f32[]) %param),
393 index=1 %add = f32[] add(f32[] %get-tuple-element, f32[] %get-tuple-element.1)
394       ROOT %tuple = (f32[], f32[]) tuple(f32[] %get-tuple-element, f32[] %add)
395     }
396 
397     %condition (param.1: (f32[], f32[])) -> pred[] {
398       %param.1 = (f32[], f32[]) parameter(0)
399       ROOT %constant = pred[] constant(false)
400     }
401 
402     %condition.1 (param.2: (f32[], f32[])) -> pred[] {
403       %param.2 = (f32[], f32[]) parameter(0)
404       ROOT %constant.1 = pred[] constant(true)
405     }
406 
407     ENTRY %WhileLoopsIdenticalBodiesAndInputDifferntConditions () -> (f32[],
408 f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2)
409       %tuple.1 = (f32[], f32[]) tuple(f32[] %constant.2, f32[] %constant.3)
410       %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1),
411 condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[],
412 f32[]) %tuple.1), condition=%condition.1, body=%body
413     })";
414 
415   TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string));
416   auto computation = m->entry_computation();
417 
418   EXPECT_EQ(5, computation->instruction_count());
419   HloCSE cse(true);
420   EXPECT_FALSE(cse.Run(m.get()).ValueOrDie());
421   EXPECT_EQ(5, computation->instruction_count());
422 }
423 
TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) {
  // Two exp instructions over the same operand whose result layouts differ
  // must *not* be merged when the pass is layout sensitive.
  HloComputation::Builder b(TestName());
  HloInstruction* input = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));

  // Build both exps from the operand's shape, then overwrite each result
  // layout so the two differ only in layout.
  HloInstruction* exp_layout01 = b.AddInstruction(
      HloInstruction::CreateUnary(input->shape(), HloOpcode::kExp, input));
  *exp_layout01->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({0, 1});

  HloInstruction* exp_layout10 = b.AddInstruction(
      HloInstruction::CreateUnary(input->shape(), HloOpcode::kExp, input));
  *exp_layout10->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({1, 0});

  HloInstruction* bundle = b.AddInstruction(
      HloInstruction::CreateTuple({exp_layout01, exp_layout10}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_EQ(4, entry->instruction_count());
  EXPECT_THAT(bundle, op::Tuple(exp_layout01, exp_layout10));

  HloCSE cse(/*is_layout_sensitive=*/true);
  const bool changed = cse.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);

  EXPECT_EQ(4, entry->instruction_count());
  EXPECT_THAT(bundle, op::Tuple(exp_layout01, exp_layout10));
}
454 
TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) {
  // Two exp instructions over the same operand whose result layouts differ
  // are merged when the pass is layout insensitive.
  HloComputation::Builder b(TestName());
  HloInstruction* input = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));

  // Build both exps from the operand's shape, then overwrite each result
  // layout so the two differ only in layout.
  HloInstruction* exp_layout01 = b.AddInstruction(
      HloInstruction::CreateUnary(input->shape(), HloOpcode::kExp, input));
  *exp_layout01->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({0, 1});

  HloInstruction* exp_layout10 = b.AddInstruction(
      HloInstruction::CreateUnary(input->shape(), HloOpcode::kExp, input));
  *exp_layout10->mutable_shape()->mutable_layout() =
      LayoutUtil::MakeLayout({1, 0});

  HloInstruction* bundle = b.AddInstruction(
      HloInstruction::CreateTuple({exp_layout01, exp_layout10}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_EQ(4, entry->instruction_count());
  EXPECT_THAT(bundle, op::Tuple(exp_layout01, exp_layout10));

  HloCSE cse(/*is_layout_sensitive=*/false);
  const bool changed = cse.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  // One exp is gone and both tuple slots read the remaining one.
  EXPECT_EQ(3, entry->instruction_count());
  const HloInstruction* survivor = bundle->operand(0);
  EXPECT_THAT(survivor, ::testing::AnyOf(exp_layout01, exp_layout10));
  EXPECT_THAT(bundle, op::Tuple(survivor, survivor));
}
487 
TEST_F(HloCseTest, FusionInternalCSE) {
  // Duplicate expressions that sit inside a fusion instruction's computation
  // must be merged as well.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder b(TestName());

  // (p0 + p1) * (p0 + p1), with the sum spelled out twice.
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  HloInstruction* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, f32_scalar, "p0"));
  HloInstruction* p1 =
      b.AddInstruction(HloInstruction::CreateParameter(1, f32_scalar, "p1"));
  HloInstruction* sum_a = b.AddInstruction(
      HloInstruction::CreateBinary(f32_scalar, HloOpcode::kAdd, p0, p1));
  HloInstruction* sum_b = b.AddInstruction(
      HloInstruction::CreateBinary(f32_scalar, HloOpcode::kAdd, p0, p1));
  HloInstruction* product = b.AddInstruction(HloInstruction::CreateBinary(
      f32_scalar, HloOpcode::kMultiply, sum_a, sum_b));

  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Fuse the multiply and both adds, then look inside the fused computation.
  auto fusion = entry->CreateFusionInstruction(
      {product, sum_a, sum_b}, HloInstruction::FusionKind::kLoop);
  auto fused = fusion->fused_instructions_computation();

  EXPECT_EQ(5, fused->instruction_count());

  HloCSE cse(/*is_layout_sensitive=*/false);
  const bool changed = cse.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  EXPECT_EQ(4, fused->instruction_count());

  // The fused root multiplies a single instruction by itself.
  auto fused_root = fused->root_instruction();
  EXPECT_THAT(fused_root,
              op::Multiply(fused_root->operand(0), fused_root->operand(0)));
}
521 
TEST_F(HloCseTest, IdenticalExpressions) {
  // Whole duplicated expression trees are merged, not just single
  // instructions. The computation under test is:
  //
  //   constant = 42.0
  //   negate1 = neg(constant)
  //   exp1 = exp(constant)
  //   add1 = add(negate1, exp1)
  //   negate2 = neg(constant)
  //   exp2 = exp(constant)
  //   add2 = add(negate2, exp2)
  //   tuple = tuple(add1, add2)
  //
  // Each *1 instruction is expected to be merged with its *2 twin.
  HloComputation::Builder b(TestName());
  HloInstruction* input = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  const Shape& shape = input->shape();

  // First copy of the expression.
  HloInstruction* neg_a = b.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, input));
  HloInstruction* exp_a = b.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kExp, input));
  HloInstruction* sum_a = b.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, neg_a, exp_a));

  // Second, structurally identical copy.
  HloInstruction* neg_b = b.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, input));
  HloInstruction* exp_b = b.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kExp, input));
  HloInstruction* sum_b = b.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, neg_b, exp_b));

  HloInstruction* bundle =
      b.AddInstruction(HloInstruction::CreateTuple({sum_a, sum_b}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_EQ(8, entry->instruction_count());
  EXPECT_THAT(bundle,
              op::Tuple(op::Add(neg_a, exp_a), op::Add(neg_b, exp_b)));

  HloCSE cse(/*is_layout_sensitive=*/false);
  const bool changed = cse.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  // Constant, one negate, one exp, one add and the tuple are left.
  EXPECT_EQ(5, entry->instruction_count());
  const HloInstruction* shared_sum = bundle->operand(0);
  EXPECT_THAT(bundle, op::Tuple(shared_sum, shared_sum));
  EXPECT_THAT(shared_sum, op::Add(op::Negate(), op::Exp()));
}
571 
TEST_F(HloCseTest, DoNotCombineRng) {
  // Test that two RNG ops are not commoned, even though they have the same
  // shape, distribution and operands.
  auto builder = HloComputation::Builder(TestName());
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  auto rng1 = builder.AddInstruction(HloInstruction::CreateRng(
      ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM,
      {constant1, constant2}));
  auto rng2 = builder.AddInstruction(HloInstruction::CreateRng(
      ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM,
      {constant1, constant2}));

  builder.AddInstruction(HloInstruction::CreateBinary(
      constant1->shape(), HloOpcode::kAdd, rng1, rng2));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  HloInstruction* root = computation->root_instruction();
  EXPECT_THAT(root, op::Add(rng1, rng2));

  // Keep the counts in whatever type instruction_count() returns rather than
  // converting them to uint32.
  const auto count_before = computation->instruction_count();

  HloCSE cse(/*is_layout_sensitive=*/false);
  EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());

  const auto count_after = computation->instruction_count();
  EXPECT_EQ(count_before, count_after);

  // The root must still add the two original, distinct RNG instructions.
  root = computation->root_instruction();
  EXPECT_THAT(root, op::Add(rng1, rng2));
}
605 
TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) {
  // Two maps over the same operand that apply a computation containing an RNG
  // must both survive: the RNG makes the mapped computation impure.
  auto module = CreateNewVerifiedModule();

  // Embedded computation: rng(0, 1) + param.
  HloComputation* rng_function = nullptr;
  {
    const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
    HloComputation::Builder fn_builder(TestName() + "_rng_fun");
    HloInstruction* low = fn_builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
    HloInstruction* high = fn_builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
    HloInstruction* random = fn_builder.AddInstruction(HloInstruction::CreateRng(
        f32_scalar, RandomDistribution::RNG_UNIFORM, {low, high}));
    HloInstruction* input = fn_builder.AddInstruction(
        HloInstruction::CreateParameter(0, f32_scalar, "param"));
    fn_builder.AddInstruction(HloInstruction::CreateBinary(
        f32_scalar, HloOpcode::kAdd, random, input));
    rng_function = module->AddEmbeddedComputation(fn_builder.Build());
  }

  // Entry computation: the sum of two maps of rng_function over one constant.
  HloComputation* entry = nullptr;
  {
    HloComputation::Builder entry_builder(TestName());
    HloInstruction* operand = entry_builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({5.0f})));
    HloInstruction* map_a = entry_builder.AddInstruction(
        HloInstruction::CreateMap(operand->shape(), {operand}, rng_function));
    HloInstruction* map_b = entry_builder.AddInstruction(
        HloInstruction::CreateMap(operand->shape(), {operand}, rng_function));
    entry_builder.AddInstruction(HloInstruction::CreateBinary(
        operand->shape(), HloOpcode::kAdd, map_a, map_b));
    entry = module->AddEntryComputation(entry_builder.Build());
  }

  EXPECT_EQ(4, entry->instruction_count());
  HloInstruction* root = entry->root_instruction();
  EXPECT_THAT(root, op::Add(op::Map(), op::Map()));

  VLOG(3) << "before: " << module->ToString();

  HloCSE cse(/*is_layout_sensitive=*/false);
  const bool changed = cse.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);

  VLOG(3) << "after: " << module->ToString();

  // Same instruction count, and the root still adds two maps of a constant.
  EXPECT_EQ(4, entry->instruction_count());
  root = entry->root_instruction();
  EXPECT_THAT(root, op::Add(op::Map(op::Constant()), op::Map(op::Constant())));
}
660 
TEST_F(HloCseTest, CompareComputations) {
  // Two reduces over the same operands whose to_apply computations are
  // declared separately but have the same structure: the pass is expected to
  // report a change and leave both tuple slots reading one instruction.
  const char* const kHloText = R"(
    HloModule m

    add_computation {
      add_lhs = f32[] parameter(0)
      add_rhs = f32[] parameter(1)
      ROOT add_root = f32[] add(add_lhs, add_rhs)
    }

    add_computation2 {
      add_lhs2 = f32[] parameter(0)
      add_rhs2 = f32[] parameter(1)
      ROOT add_root2 = f32[] add(add_lhs2, add_rhs2)
    }

    ENTRY entry {
      p = f32[10]{0} parameter(0)
      c = f32[] constant(0)
      r1 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation
      r2 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation2
      ROOT f2 = (f32[],f32[]) tuple(r1, r2)
    })";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  HloCSE cse(/*is_layout_sensitive=*/false);
  const bool changed = cse.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  HloInstruction* result_tuple = module->entry_computation()->root_instruction();
  EXPECT_EQ(result_tuple->operand(0), result_tuple->operand(1));
}
691 
TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) {
  // Equal constants that live in different (here: disjoint) domains are not
  // collapsed. The computation is just two unconnected uint32 constants.
  HloComputation::Builder b(TestName());
  b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(42)));
  b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(42)));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());

  EXPECT_EQ(2, entry->instruction_count());

  HloCSE cse(/*is_layout_sensitive=*/false);
  const bool changed = cse.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed);

  EXPECT_EQ(2, entry->instruction_count());
}
711 
TEST_F(HloCseTest, Domain) {
  // Three domain/negate/domain chains hang off one parameter. Chains 0 and 1
  // carry identical domain attributes; chain 2 uses device 2 instead of
  // device 1. After the pass the add's two operands must be the same
  // instruction, and that instruction must differ from the subtract's second
  // operand.
  const char* const kHloText = R"(
HloModule module
ENTRY %entry {
  %param = f32[] parameter(0), sharding={maximal device=0}
  %domain.0 = f32[] domain(%param),
    domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
  %domain.1 = f32[] domain(%param),
    domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
  %domain.2 = f32[] domain(%param),
    domain={kind="sharding", entry={maximal device=0}, exit={maximal device=2}}
  %negate.0 = f32[] negate(%domain.0)
  %negate.1 = f32[] negate(%domain.1)
  %negate.2 = f32[] negate(%domain.2)
  %domain.3 = f32[] domain(%negate.0),
    domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}}
  %domain.4 = f32[] domain(%negate.1),
    domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}}
  %domain.5 = f32[] domain(%negate.2),
    domain={kind="sharding", entry={maximal device=2}, exit={maximal device=0}}
  %add = f32[] add(%domain.3, %domain.4)
  ROOT %sub = f32[] subtract(%add, %domain.5)
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  HloCSE cse(/*is_layout_sensitive=*/false);
  const bool changed = cse.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  const HloInstruction* difference =
      module->entry_computation()->root_instruction();
  const HloInstruction* sum = difference->operand(0);
  const HloInstruction* other_chain = difference->operand(1);
  EXPECT_EQ(sum->operand(0), sum->operand(1));
  EXPECT_NE(sum->operand(0), other_chain);
  EXPECT_NE(sum->operand(1), other_chain);
}
745 
746 }  // namespace
747 }  // namespace xla
748