• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_computation.h"
17 
18 #include <memory>
19 #include <set>
20 #include <vector>
21 
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
28 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
29 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
30 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/test.h"
33 #include "tensorflow/compiler/xla/test_helpers.h"
34 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
35 
36 namespace xla {
37 
38 namespace {
39 
40 namespace m = match;
41 namespace op = xla::testing::opcode_matchers;
42 using ::testing::ElementsAre;
43 using ::testing::UnorderedElementsAre;
44 
45 class HloComputationTest : public HloTestBase {
46  protected:
HloComputationTest()47   HloComputationTest() {}
48 
49   // Create a computation which takes a scalar and returns its negation.
CreateNegateComputation()50   std::unique_ptr<HloComputation> CreateNegateComputation() {
51     auto builder = HloComputation::Builder("Negate");
52     auto param = builder.AddInstruction(
53         HloInstruction::CreateParameter(0, r0f32_, "param0"));
54     builder.AddInstruction(
55         HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param));
56     return builder.Build();
57   }
58 
59   // Creates a computation which calls map with the given computation.
CreateMapComputation(HloComputation * map_computation)60   std::unique_ptr<HloComputation> CreateMapComputation(
61       HloComputation* map_computation) {
62     auto builder = HloComputation::Builder("Map");
63     auto param = builder.AddInstruction(
64         HloInstruction::CreateParameter(0, r0f32_, "param0"));
65     builder.AddInstruction(
66         HloInstruction::CreateMap(r0f32_, {param}, map_computation));
67     return builder.Build();
68   }
69 
70   Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
71 };
72 
// A computation that calls no other computation embeds nothing.
TEST_F(HloComputationTest, GetEmbeddedComputationsEmpty) {
  auto module = CreateNewVerifiedModule();
  HloComputation* const entry =
      module->AddEntryComputation(CreateNegateComputation());
  const bool has_no_callees = entry->MakeEmbeddedComputationsList().empty();
  EXPECT_TRUE(has_no_callees);
}
79 
// A Map computation embeds exactly the computation it maps, and that callee
// (a plain negate) embeds nothing itself.
TEST_F(HloComputationTest, GetEmbeddedComputationsOneComputation) {
  auto module = CreateNewVerifiedModule();
  HloComputation* const callee =
      module->AddEmbeddedComputation(CreateNegateComputation());
  HloComputation* const caller =
      module->AddEntryComputation(CreateMapComputation(callee));

  EXPECT_TRUE(callee->MakeEmbeddedComputationsList().empty());
  EXPECT_THAT(caller->MakeEmbeddedComputationsList(), ElementsAre(callee));
}
91 
// Builds a diamond-shaped call graph
//   entry -> {map1, map2},  map1 -> negate,  map2 -> negate
// and checks that the entry reports all three callees, with the shared leaf
// (negate) ordered first.
TEST_F(HloComputationTest, GetEmbeddedComputationsDiamond) {
  auto module = CreateNewVerifiedModule();
  auto negate_computation =
      module->AddEmbeddedComputation(CreateNegateComputation());
  auto map1_computation =
      module->AddEmbeddedComputation(CreateMapComputation(negate_computation));
  auto map2_computation =
      module->AddEmbeddedComputation(CreateMapComputation(negate_computation));

  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "param0"));
  auto map1 = builder.AddInstruction(
      HloInstruction::CreateMap(r0f32_, {param}, map1_computation));
  auto map2 = builder.AddInstruction(
      HloInstruction::CreateMap(r0f32_, {param}, map2_computation));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, map1, map2));
  auto computation = module->AddEntryComputation(builder.Build());

  auto embedded_computations = computation->MakeEmbeddedComputationsList();
  // ASSERT rather than EXPECT: front() below must never run on an empty list,
  // which would be undefined behavior.
  ASSERT_EQ(embedded_computations.size(), 3u);
  // MakeEmbeddedComputationsList returns the embedded computations in post
  // order, so the negate computation (called by both maps) must come first.
  EXPECT_EQ(negate_computation, embedded_computations.front());
  EXPECT_THAT(embedded_computations,
              UnorderedElementsAre(negate_computation, map1_computation,
                                   map2_computation));
}
122 
// MakeInstructionPostOrder on a one-instruction computation yields exactly
// that instruction.
TEST_F(HloComputationTest, PostOrderSingleton) {
  HloComputation::Builder builder(TestName());
  HloInstruction* const only_instruction = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(entry->MakeInstructionPostOrder(),
              ElementsAre(only_instruction));
}
132 
// For the linear chain constant -> negate -> negate, MakeInstructionPostOrder
// must list every instruction after its single operand.
TEST_F(HloComputationTest, PostOrderSimple) {
  HloComputation::Builder builder(TestName());
  HloInstruction* const constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* const first_negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
  HloInstruction* const second_negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, first_negate));
  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(entry->MakeInstructionPostOrder(),
              ElementsAre(constant, first_negate, second_negate));
}
148 
// Four constants with no data edges between them must all appear in the post
// order; the test does not constrain their relative order.
TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) {
  HloComputation::Builder builder(TestName());
  // Adds one f32[] constant 42 to the builder and returns it.
  auto add_constant = [&builder]() {
    return builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  };
  HloInstruction* const c1 = add_constant();
  HloInstruction* const c2 = add_constant();
  HloInstruction* const c3 = add_constant();
  HloInstruction* const c4 = add_constant();
  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(builder.Build());
  EXPECT_THAT(entry->MakeInstructionPostOrder(),
              UnorderedElementsAre(c1, c2, c3, c4));
}
166 
TEST_F(HloComputationTest, PostOrderWithMultipleRoots) {
  // Test MakeInstructionPostOrder for a computation with multiple roots: three
  // adds consume overlapping pairs of the three constants, and no add is used
  // by any other instruction. All six instructions must still be emitted.
  auto builder = HloComputation::Builder(TestName());
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  auto constant3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, constant1, constant2));
  auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, constant2, constant3));
  auto add3 = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, constant1, constant3));
  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  auto post_order = computation->MakeInstructionPostOrder();
  // Unsigned literal: compares against size() without a sign-compare warning.
  EXPECT_EQ(post_order.size(), 6u);
  EXPECT_THAT(post_order, UnorderedElementsAre(constant1, constant2, constant3,
                                               add1, add2, add3));
}
190 
// Accept() must visit every instruction exactly once even when the
// computation has multiple roots (dead code). It must also call FinishVisit
// exactly once with the computation root and visit that root last.
TEST_F(HloComputationTest, VisitWithMultipleRoots) {
  HloComputation::Builder builder(TestName());
  HloInstruction* const c1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* const c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* const c3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  // Three adds, none of which consumes another, so all but the root are dead.
  builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1, c2));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c2, c3));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1, c3));
  auto module = CreateNewVerifiedModule();
  HloComputation* const computation =
      module->AddEntryComputation(builder.Build());

  // Records every instruction it is handed and fails on a repeat visit.
  class TestVisitor : public DfsHloVisitorWithDefault {
   public:
    explicit TestVisitor(HloComputation* computation)
        : computation_(computation) {}

    Status DefaultAction(HloInstruction* hlo_instruction) override {
      // insert().second is false when the instruction was already recorded.
      EXPECT_TRUE(visited_set_.insert(hlo_instruction).second);
      last_visited_ = hlo_instruction;
      return OkStatus();
    }

    Status FinishVisit(HloInstruction* root) override {
      EXPECT_EQ(computation_->root_instruction(), root);
      ++finish_visit_calls_;
      return OkStatus();
    }

    HloComputation* computation_;
    absl::flat_hash_set<HloInstruction*> visited_set_;
    int64_t finish_visit_calls_ = 0;
    HloInstruction* last_visited_ = nullptr;
  };

  TestVisitor visitor(computation);
  EXPECT_IS_OK(computation->Accept(&visitor));

  EXPECT_EQ(visitor.visited_set_.size(), 6u);
  EXPECT_EQ(visitor.finish_visit_calls_, 1);
  EXPECT_EQ(visitor.last_visited_, computation->root_instruction());
}
242 
// Deep-copying an array-shaped instruction yields a single kCopy of it.
TEST_F(HloComputationTest, DeepCopyArray) {
  HloComputation::Builder builder(TestName());
  HloInstruction* const array =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(builder.Build());
  HloInstruction* const copied = entry->DeepCopyInstruction(array).ValueOrDie();

  EXPECT_THAT(copied, GmockMatch(m::Copy(m::Op().Is(array))));
}
254 
// Deep-copying a tuple must produce a new tuple whose elements are copies of
// get-tuple-elements of the original tuple, in element order.
TEST_F(HloComputationTest, DeepCopyTuple) {
  auto builder = HloComputation::Builder(TestName());
  auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  auto tuple_copy = computation->DeepCopyInstruction(tuple).ValueOrDie();

  // ASSERT (not EXPECT) the structure. The index checks below walk
  // operand(i)->operand(0) and call tuple_index(), which is only meaningful
  // for get-tuple-element instructions, so they must not run on a mismatch.
  ASSERT_THAT(tuple_copy, GmockMatch(m::Tuple(
                              m::Copy(m::GetTupleElement(m::Op().Is(tuple))),
                              m::Copy(m::GetTupleElement(m::Op().Is(tuple))))));
  EXPECT_EQ(0, tuple_copy->operand(0)->operand(0)->tuple_index());
  EXPECT_EQ(1, tuple_copy->operand(1)->operand(0)->tuple_index());
}
275 
// For an array-shaped instruction, the `indices_to_copy` tree decides whether
// DeepCopyInstruction inserts a kCopy (true) or returns the instruction
// itself (false).
TEST_F(HloComputationTest, DeepCopyArrayAtIndices) {
  HloComputation::Builder builder(TestName());
  HloInstruction* const array =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
  std::unique_ptr<HloComputation> computation = builder.Build();

  {
    // Marked for copying: a kCopy of the array is expected.
    ShapeTree<bool> copy_everything(array->shape(), /*init_value=*/true);
    HloInstruction* const result =
        computation->DeepCopyInstruction(array, &copy_everything)
            .ValueOrDie();
    EXPECT_THAT(result, GmockMatch(m::Copy(m::Op().Is(array))));
  }

  {
    // Not marked for copying: the original instruction comes back unchanged.
    ShapeTree<bool> copy_nothing(array->shape(), /*init_value=*/false);
    HloInstruction* const result =
        computation->DeepCopyInstruction(array, &copy_nothing).ValueOrDie();
    EXPECT_EQ(result, array);
  }
}
300 
TEST_F(HloComputationTest,DeepCopyTupleAtIndices)301 TEST_F(HloComputationTest, DeepCopyTupleAtIndices) {
302   // Test that DeepCopyInstruction properly copies elements of a tuple as
303   // specified by the given indices.
304   auto builder = HloComputation::Builder(TestName());
305   auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
306       LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
307   auto constant2 = builder.AddInstruction(
308       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
309   auto tuple = builder.AddInstruction(
310       HloInstruction::CreateTuple({constant1, constant2}));
311   auto computation = builder.Build();
312 
313   {
314     // All true values should copy all array elements.
315     ShapeTree<bool> indices_to_copy(tuple->shape(), /*init_value=*/true);
316     ShapeTree<HloInstruction*> copies_added(tuple->shape(),
317                                             /*init_value=*/nullptr);
318     HloInstruction* deep_copy =
319         computation->DeepCopyInstruction(tuple, &indices_to_copy, &copies_added)
320             .ValueOrDie();
321 
322     EXPECT_THAT(deep_copy, GmockMatch(m::Tuple(
323                                m::Copy(m::GetTupleElement(m::Op().Is(tuple)))
324                                    .Is(copies_added.element({0})),
325                                m::Copy(m::GetTupleElement(m::Op().Is(tuple)))
326                                    .Is(copies_added.element({1})))));
327   }
328 
329   {
330     // All false elements should copy no array elements, but the GTE and tuple
331     // instruction scaffolding should be built.
332     ShapeTree<bool> indices_to_copy(tuple->shape(), /*init_value=*/false);
333     ShapeTree<HloInstruction*> copies_added(tuple->shape(),
334                                             /*init_value=*/nullptr);
335     HloInstruction* deep_copy =
336         computation->DeepCopyInstruction(tuple, &indices_to_copy, &copies_added)
337             .ValueOrDie();
338 
339     EXPECT_THAT(deep_copy,
340                 GmockMatch(m::Tuple(m::GetTupleElement(m::Op().Is(tuple)),
341                                     m::GetTupleElement(m::Op().Is(tuple)))));
342     EXPECT_TRUE(copies_added.element({}) == nullptr);
343     EXPECT_TRUE(copies_added.element({0}) == nullptr);
344     EXPECT_TRUE(copies_added.element({1}) == nullptr);
345   }
346 
347   {
348     // Verify one element copied, the other not.
349     ShapeTree<bool> indices_to_copy(tuple->shape(), /*init_value=*/false);
350     *indices_to_copy.mutable_element({0}) = true;
351     ShapeTree<HloInstruction*> copies_added(tuple->shape(),
352                                             /*init_value=*/nullptr);
353     HloInstruction* deep_copy =
354         computation->DeepCopyInstruction(tuple, &indices_to_copy, &copies_added)
355             .ValueOrDie();
356 
357     EXPECT_THAT(deep_copy, GmockMatch(m::Tuple(
358                                m::Copy(m::GetTupleElement(m::Op().Is(tuple))),
359                                m::GetTupleElement(m::Op().Is(tuple)))));
360     EXPECT_TRUE(copies_added.element({}) == nullptr);
361     EXPECT_TRUE(copies_added.element({0}) != nullptr);
362     EXPECT_TRUE(copies_added.element({1}) == nullptr);
363   }
364 }
365 
// Tokens must not be copied: deep-copying a token yields an after-all
// instruction rather than a kCopy.
TEST_F(HloComputationTest, DeepCopyToken) {
  HloComputation::Builder builder(TestName());
  HloInstruction* const token =
      builder.AddInstruction(HloInstruction::CreateToken());
  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(builder.Build());
  HloInstruction* const result = entry->DeepCopyInstruction(token).ValueOrDie();

  EXPECT_THAT(result, GmockMatch(m::AfterAll()));
}
378 
// Deep-copying a (token, f32[]) tuple copies only the array element. The
// token element is forwarded through a get-tuple-element without a kCopy.
TEST_F(HloComputationTest, DeepCopyTokenTuple) {
  HloComputation::Builder builder(TestName());
  HloInstruction* const token =
      builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* const scalar = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  HloInstruction* const tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({token, scalar}));
  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(builder.Build());
  HloInstruction* const result = entry->DeepCopyInstruction(tuple).ValueOrDie();

  const auto forwarded_token = m::GetTupleElement(m::Op().Is(tuple));
  const auto copied_scalar = m::Copy(m::GetTupleElement(m::Op().Is(tuple)));
  EXPECT_THAT(result, GmockMatch(m::Tuple(forwarded_token, copied_scalar)));
}
398 
// A control edge from `add` to `negate`, where `add` already consumes
// `negate` as data, closes a cycle. MakeInstructionPostOrder still lists all
// three instructions, but Accept() must fail with a cycle diagnostic.
TEST_F(HloComputationTest, CycleDetection) {
  HloComputation::Builder builder(TestName());
  HloInstruction* const constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* const negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
  HloInstruction* const add = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, negate, negate));
  // Unverified module, presumably because the cyclic graph created below
  // would not pass verification.
  auto module = CreateNewUnverifiedModule();
  HloComputation* const computation =
      module->AddEntryComputation(builder.Build());
  // The control dependency is what creates the cycle.
  ASSERT_IS_OK(add->AddControlDependencyTo(negate));

  EXPECT_EQ(computation->MakeInstructionPostOrder().size(), 3u);

  FunctionVisitor visitor([](HloInstruction*) { return OkStatus(); });
  const auto visit_status = computation->Accept(&visitor);
  ASSERT_FALSE(visit_status.ok());
  // The pattern is a prefix of "cycle is detected"; ContainsRegex finds it
  // anywhere in the message.
  ASSERT_THAT(visit_status.error_message(),
              ::testing::ContainsRegex("cycle is detecte"));
}
423 
// RemoveInstructionAndUnusedOperands on an add whose two operands are the
// same dead negate must delete that negate only once. The negate root and the
// constant it consumes must survive.
TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) {
  HloComputation::Builder builder(TestName());
  HloInstruction* const constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* const dead_negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
  HloInstruction* const dead_add =
      builder.AddInstruction(HloInstruction::CreateBinary(
          r0f32_, HloOpcode::kAdd, dead_negate, dead_negate));
  HloInstruction* const negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
  auto module = CreateNewVerifiedModule();
  HloComputation* const computation =
      module->AddEntryComputation(builder.Build());

  // Checks that the root is still `negate`, consuming `constant`.
  auto expect_negate_is_root = [&]() {
    EXPECT_THAT(computation->root_instruction(),
                GmockMatch(m::Negate(m::Op().Is(constant))));
    EXPECT_EQ(negate, computation->root_instruction());
  };

  EXPECT_EQ(4, computation->instruction_count());
  expect_negate_is_root();

  ASSERT_IS_OK(computation->RemoveInstructionAndUnusedOperands(dead_add));

  EXPECT_EQ(2, computation->instruction_count());
  expect_negate_is_root();
}
451 
// Cloning a computation must also clone control dependencies. The cloned root
// (add) must have exactly one control predecessor, a negate, whose only
// control successor is the cloned add.
TEST_F(HloComputationTest, CloneWithControlDependency) {
  auto builder = HloComputation::Builder(TestName());
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0f)));
  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, constant1, constant2));

  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "param0"));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param));
  auto module = CreateNewVerifiedModule();
  auto computation =
      module->AddEntryComputation(builder.Build(/*root_instruction=*/add));

  TF_CHECK_OK(negate->AddControlDependencyTo(add));

  auto clone = computation->Clone();

  HloInstruction* const cloned_add = clone->root_instruction();
  EXPECT_EQ(cloned_add->opcode(), HloOpcode::kAdd);

  // Bind by reference to avoid copying the vector. ASSERT the size so that a
  // missing predecessor ends the test instead of indexing an empty vector.
  const auto& predecessors = cloned_add->control_predecessors();
  ASSERT_EQ(predecessors.size(), 1u);
  HloInstruction* const cloned_negate = predecessors[0];
  EXPECT_EQ(HloOpcode::kNegate, cloned_negate->opcode());
  EXPECT_THAT(cloned_negate->control_successors(), ElementsAre(cloned_add));
}
482 
// CloneWithReplacements must substitute each replaced instruction (here an
// s64 parameter swapped for an s32 one) and append `extra_parameters` after
// the existing parameters.
TEST_F(HloComputationTest, CloneWithReplacements) {
  HloComputation::Builder builder(TestName());
  const Shape r0s64 = ShapeUtil::MakeShape(S64, {});
  const Shape r0s32 = ShapeUtil::MakeShape(S32, {});
  const Shape r0u32 = ShapeUtil::MakeShape(U32, {});
  HloInstruction* const lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "p.0.lhs"));
  HloInstruction* const rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32_, "p.0.rhs"));
  HloInstruction* const s64_param =
      builder.AddInstruction(HloInstruction::CreateParameter(2, r0s64, "p.1"));
  HloInstruction* const less_than = builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), lhs, rhs,
                                    ComparisonDirection::kLt));
  auto module = CreateNewVerifiedModule();
  HloComputation* const computation = module->AddEntryComputation(
      builder.Build(/*root_instruction=*/less_than));

  absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
      replacements;
  replacements.emplace(s64_param,
                       HloInstruction::CreateParameter(2, r0s32, "p.1"));
  // Extra parameter handed to the clone; it does not belong to `computation`.
  std::unique_ptr<HloInstruction> appended_param =
      HloInstruction::CreateParameter(3, r0u32, "p.2");
  std::vector<const HloInstruction*> extra_parameters{appended_param.get()};
  auto clone =
      computation->CloneWithReplacements(&replacements, extra_parameters);

  ASSERT_EQ(clone->num_parameters(), 4);
  const Shape expected_param_shapes[] = {r0f32_, r0f32_, r0s32, r0u32};
  for (int i = 0; i < 4; ++i) {
    EXPECT_TRUE(ShapeUtil::Equal(clone->parameter_instruction(i)->shape(),
                                 expected_param_shapes[i]));
  }
}
518 
// ToString() of a transpose+dot computation, with metadata printing disabled,
// must match the expected HLO text, including the execution-thread suffix.
TEST_F(HloComputationTest, Stringification) {
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape rhs_transposed_shape = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {5, 20});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* const x = builder.AddInstruction(
      HloInstruction::CreateParameter(0, lhs_shape, "x"));
  HloInstruction* const y = builder.AddInstruction(
      HloInstruction::CreateParameter(1, rhs_shape, "y"));
  HloInstruction* const transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(rhs_transposed_shape, y, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  // One DEFAULT precision entry for each of the two dot operands.
  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      2, PrecisionConfig::DEFAULT);
  builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, x, transpose, dot_dnums, precision_config));
  auto module = CreateNewVerifiedModule();
  HloComputation* const computation =
      module->AddEntryComputation(builder.Build());
  computation->SetExecutionThread("MainThread");

  const HloPrintOptions options = HloPrintOptions().set_print_metadata(false);
  const std::string expected_computation =
      R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
  %x = f32[5,10]{1,0} parameter(0)
  %y = f32[20,10]{1,0} parameter(1)
  %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0}
  ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}, execution_thread="MainThread")";
  EXPECT_EQ(computation->ToString(options), expected_computation);
}
554 
// Same computation as the Stringification test, printed with
// set_indent_amount(2): every line of the output gains leading indentation.
TEST_F(HloComputationTest, StringificationIndent) {
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape rhs_transposed_shape = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {5, 20});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* const x = builder.AddInstruction(
      HloInstruction::CreateParameter(0, lhs_shape, "x"));
  HloInstruction* const y = builder.AddInstruction(
      HloInstruction::CreateParameter(1, rhs_shape, "y"));
  HloInstruction* const transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(rhs_transposed_shape, y, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  // One DEFAULT precision entry for each of the two dot operands.
  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      2, PrecisionConfig::DEFAULT);
  builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, x, transpose, dot_dnums, precision_config));
  auto module = CreateNewVerifiedModule();
  HloComputation* const computation =
      module->AddEntryComputation(builder.Build());
  computation->SetExecutionThread("MainThread");

  const HloPrintOptions options =
      HloPrintOptions().set_print_metadata(false).set_indent_amount(2);
  const std::string expected_computation =
      R"(    %TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
      %x = f32[5,10]{1,0} parameter(0)
      %y = f32[20,10]{1,0} parameter(1)
      %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0}
      ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    }, execution_thread="MainThread")";
  EXPECT_EQ(computation->ToString(options), expected_computation);
}
591 
// Prints the transpose+dot computation twice. The first print uses default
// options without metadata and must keep the original names. The second uses
// HloPrintOptions::Canonical() and must rename instructions to tmp_N and drop
// the leading '%' sigils.
TEST_F(HloComputationTest, StringificationCanonical) {
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape rhs_transposed_shape = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {5, 20});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* const x = builder.AddInstruction(
      HloInstruction::CreateParameter(0, lhs_shape, "x"));
  HloInstruction* const y = builder.AddInstruction(
      HloInstruction::CreateParameter(1, rhs_shape, "y"));
  HloInstruction* const transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(rhs_transposed_shape, y, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  // One DEFAULT precision entry for each of the two dot operands.
  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      2, PrecisionConfig::DEFAULT);
  builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, x, transpose, dot_dnums, precision_config));
  auto module = CreateNewVerifiedModule();
  HloComputation* const computation =
      module->AddEntryComputation(builder.Build());
  computation->SetExecutionThread("MainThread");

  const HloPrintOptions named_options =
      HloPrintOptions().set_print_metadata(false);
  const std::string expected_computation1 =
      R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
  %x = f32[5,10]{1,0} parameter(0)
  %y = f32[20,10]{1,0} parameter(1)
  %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0}
  ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}, execution_thread="MainThread")";
  EXPECT_EQ(computation->ToString(named_options), expected_computation1);

  const HloPrintOptions canonical_options = HloPrintOptions().Canonical();
  const std::string expected_computation2 = R"(TransposeDot {
  tmp_0 = f32[5,10]{1,0} parameter(0)
  tmp_1 = f32[20,10]{1,0} parameter(1)
  tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
  ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}, execution_thread="MainThread")";
  EXPECT_EQ(computation->ToString(canonical_options), expected_computation2);
}
636 
MakeAddNComputation(int n)637 std::unique_ptr<HloComputation> MakeAddNComputation(int n) {
638   auto builder = HloComputation::Builder("add_n");
639   auto result = builder.AddInstruction(HloInstruction::CreateParameter(
640       0, ShapeUtil::MakeShape(F32, {}), "x_value"));
641   auto one = builder.AddInstruction(
642       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
643   for (int i = 0; i < n; ++i) {
644     result = builder.AddInstruction(HloInstruction::CreateBinary(
645         one->shape(), HloOpcode::kAdd, result, one));
646   }
647   return builder.Build();
648 }
649 
// HloComputation::operator== must handle very deep chains (200000 adds). It
// must treat equal-length chains as equal and chains that differ in length
// by one as different, in either argument order.
TEST_F(HloComputationTest, DeepEquality) {
  constexpr int kDepth = 200000;
  const auto lhs = MakeAddNComputation(kDepth);
  const auto rhs = MakeAddNComputation(kDepth);
  EXPECT_TRUE(*lhs == *rhs);

  const auto one_shorter = MakeAddNComputation(kDepth - 1);
  EXPECT_FALSE(*lhs == *one_shorter);
  EXPECT_FALSE(*one_shorter == *rhs);
}
659 
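// Tests that cross-module AllReduce instructions sharing a channel_id are placed
// in the post order after all of their predecessors and before all of their
// successors; the expected order below has both AllReduces adjacent, ahead of
// `add`.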
TEST_F(HloComputationTest, InstructionPostOrderWithAllReduce) {
  constexpr char kHloString[] = R"(
HloModule Module

add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY entry {
  param = f32[128] parameter(0), sharding={maximal device=0}
  crs0 = f32[128] all-reduce(param),
    replica_groups={{0}}, channel_id=1, to_apply=add,
    sharding={maximal device=0}
  crs1 = f32[128] all-reduce(param),
    replica_groups={{0}}, channel_id=1, to_apply=add,
    sharding={maximal device=1}
  add = f32[128] add(crs0, crs0), sharding={maximal device=0}
  ROOT t = (f32[128], f32[128]) tuple(add, crs1)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kHloString));

  // crs0 and crs1 share channel_id=1. The expected order places crs1
  // immediately after crs0 and ahead of `add`, even though only the root
  // tuple uses crs1.
  const auto post_order =
      module->entry_computation()->MakeInstructionPostOrder();
  EXPECT_THAT(post_order,
              ElementsAre(op::Parameter(), op::AllReduce(), op::AllReduce(),
                          op::Add(), op::Tuple()));
}
689 
690 }  // namespace
691 }  // namespace xla
692