1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_computation.h"
17
18 #include <memory>
19 #include <set>
20 #include <vector>
21
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
28 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
29 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
30 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/test.h"
33 #include "tensorflow/compiler/xla/test_helpers.h"
34 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
35
36 namespace xla {
37
38 namespace {
39
40 namespace m = match;
41 namespace op = xla::testing::opcode_matchers;
42 using ::testing::ElementsAre;
43 using ::testing::UnorderedElementsAre;
44
45 class HloComputationTest : public HloTestBase {
46 protected:
HloComputationTest()47 HloComputationTest() {}
48
49 // Create a computation which takes a scalar and returns its negation.
CreateNegateComputation()50 std::unique_ptr<HloComputation> CreateNegateComputation() {
51 auto builder = HloComputation::Builder("Negate");
52 auto param = builder.AddInstruction(
53 HloInstruction::CreateParameter(0, r0f32_, "param0"));
54 builder.AddInstruction(
55 HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param));
56 return builder.Build();
57 }
58
59 // Creates a computation which calls map with the given computation.
CreateMapComputation(HloComputation * map_computation)60 std::unique_ptr<HloComputation> CreateMapComputation(
61 HloComputation* map_computation) {
62 auto builder = HloComputation::Builder("Map");
63 auto param = builder.AddInstruction(
64 HloInstruction::CreateParameter(0, r0f32_, "param0"));
65 builder.AddInstruction(
66 HloInstruction::CreateMap(r0f32_, {param}, map_computation));
67 return builder.Build();
68 }
69
70 Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
71 };
72
TEST_F(HloComputationTest, GetEmbeddedComputationsEmpty) {
  // A computation that calls nothing must report no embedded computations.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry =
      module->AddEntryComputation(CreateNegateComputation());
  const auto embedded = entry->MakeEmbeddedComputationsList();
  EXPECT_TRUE(embedded.empty());
}
79
TEST_F(HloComputationTest, GetEmbeddedComputationsOneComputation) {
  // The entry computation maps over exactly one other computation.
  auto module = CreateNewVerifiedModule();
  HloComputation* callee =
      module->AddEmbeddedComputation(CreateNegateComputation());
  HloComputation* caller =
      module->AddEntryComputation(CreateMapComputation(callee));

  // The callee calls nothing; the caller embeds only the callee.
  EXPECT_TRUE(callee->MakeEmbeddedComputationsList().empty());
  EXPECT_THAT(caller->MakeEmbeddedComputationsList(), ElementsAre(callee));
}
91
TEST_F(HloComputationTest, GetEmbeddedComputationsDiamond) {
  // Builds a diamond-shaped call graph: the entry computation maps over two
  // computations, and each of those maps over the same negate computation.
  auto module = CreateNewVerifiedModule();
  auto negate_computation =
      module->AddEmbeddedComputation(CreateNegateComputation());
  auto map1_computation =
      module->AddEmbeddedComputation(CreateMapComputation(negate_computation));
  auto map2_computation =
      module->AddEmbeddedComputation(CreateMapComputation(negate_computation));

  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "param0"));
  auto map1 = builder.AddInstruction(
      HloInstruction::CreateMap(r0f32_, {param}, map1_computation));
  auto map2 = builder.AddInstruction(
      HloInstruction::CreateMap(r0f32_, {param}, map2_computation));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, map1, map2));
  auto computation = module->AddEntryComputation(builder.Build());

  auto embedded_computations = computation->MakeEmbeddedComputationsList();
  // Assert (not expect) the size: the first element is read below, which
  // would be invalid on an empty list.
  ASSERT_EQ(3, embedded_computations.size());
  // MakeEmbeddedComputationsList returns the embedded computations in post
  // order, so the shared negate computation must come first.
  EXPECT_EQ(negate_computation, embedded_computations.front());
  EXPECT_THAT(embedded_computations,
              UnorderedElementsAre(negate_computation, map1_computation,
                                   map2_computation));
}
122
TEST_F(HloComputationTest, PostOrderSingleton) {
  // MakeInstructionPostOrder on a computation holding a single constant must
  // return exactly that constant.
  HloComputation::Builder builder(TestName());
  HloInstruction* only_instruction = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  const auto post_order = computation->MakeInstructionPostOrder();
  EXPECT_THAT(post_order, ElementsAre(only_instruction));
}
132
TEST_F(HloComputationTest, PostOrderSimple) {
  // MakeInstructionPostOrder on a linear chain constant -> negate -> negate
  // must list each instruction after its operand.
  HloComputation::Builder builder(TestName());
  HloInstruction* source = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* first_negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, source));
  HloInstruction* second_negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, first_negate));
  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  const auto post_order = computation->MakeInstructionPostOrder();
  EXPECT_THAT(post_order, ElementsAre(source, first_negate, second_negate));
}
148
TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) {
  // MakeInstructionPostOrder on four constants with no edges between them
  // must return all four, in any order.
  HloComputation::Builder builder(TestName());
  auto add_constant = [&builder]() {
    return builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  };
  HloInstruction* c1 = add_constant();
  HloInstruction* c2 = add_constant();
  HloInstruction* c3 = add_constant();
  HloInstruction* c4 = add_constant();

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  const auto post_order = computation->MakeInstructionPostOrder();
  EXPECT_THAT(post_order, UnorderedElementsAre(c1, c2, c3, c4));
}
166
TEST_F(HloComputationTest, PostOrderWithMultipleRoots) {
  // MakeInstructionPostOrder on a computation with several roots: three adds
  // over three shared constants, none of which is used by another
  // instruction. All six instructions must be returned.
  HloComputation::Builder builder(TestName());
  auto add_constant = [&builder]() {
    return builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  };
  HloInstruction* c1 = add_constant();
  HloInstruction* c2 = add_constant();
  HloInstruction* c3 = add_constant();

  HloInstruction* sum12 = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1, c2));
  HloInstruction* sum23 = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c2, c3));
  HloInstruction* sum13 = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1, c3));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  const auto ordered = computation->MakeInstructionPostOrder();
  EXPECT_EQ(6, ordered.size());
  EXPECT_THAT(ordered,
              UnorderedElementsAre(c1, c2, c3, sum12, sum23, sum13));
}
190
TEST_F(HloComputationTest, VisitWithMultipleRoots) {
  // Accept must visit every instruction of a computation that has several
  // roots (dead code), visit each one once, call FinishVisit once, and visit
  // the computation's root instruction last.

  // Visitor that records which instructions it has been handed.
  class RecordingVisitor : public DfsHloVisitorWithDefault {
   public:
    explicit RecordingVisitor(HloComputation* computation)
        : computation_(computation) {}

    Status DefaultAction(HloInstruction* hlo_instruction) override {
      // insert().second is false when the instruction was already recorded,
      // i.e. when it is being visited a second time.
      const bool first_visit = seen_.insert(hlo_instruction).second;
      EXPECT_TRUE(first_visit);
      last_seen_ = hlo_instruction;
      return OkStatus();
    }

    Status FinishVisit(HloInstruction* root) override {
      EXPECT_EQ(computation_->root_instruction(), root);
      ++finish_calls_;
      return OkStatus();
    }

    HloComputation* computation_;
    absl::flat_hash_set<HloInstruction*> seen_;
    int64_t finish_calls_ = 0;
    HloInstruction* last_seen_ = nullptr;
  };

  HloComputation::Builder builder(TestName());
  auto add_constant = [&builder]() {
    return builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  };
  HloInstruction* c1 = add_constant();
  HloInstruction* c2 = add_constant();
  HloInstruction* c3 = add_constant();

  // Three add expressions that do not feed into one another.
  builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1, c2));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c2, c3));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1, c3));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  RecordingVisitor visitor(computation);
  EXPECT_IS_OK(computation->Accept(&visitor));

  EXPECT_EQ(6, visitor.seen_.size());
  EXPECT_EQ(1, visitor.finish_calls_);
  EXPECT_EQ(computation->root_instruction(), visitor.last_seen_);
}
242
TEST_F(HloComputationTest, DeepCopyArray) {
  // DeepCopyInstruction on an array-shaped instruction must produce a copy of
  // that instruction.
  auto builder = HloComputation::Builder(TestName());
  auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  // Report a failed status as a test failure instead of aborting the test
  // binary.
  TF_ASSERT_OK_AND_ASSIGN(HloInstruction * copy,
                          computation->DeepCopyInstruction(constant));

  EXPECT_THAT(copy, GmockMatch(m::Copy(m::Op().Is(constant))));
}
254
TEST_F(HloComputationTest, DeepCopyTuple) {
  // DeepCopyInstruction on a two-element tuple must produce a new tuple whose
  // elements are copies of get-tuple-elements of the original tuple.
  auto builder = HloComputation::Builder(TestName());
  auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  // Report a failed status as a test failure instead of aborting the test
  // binary.
  TF_ASSERT_OK_AND_ASSIGN(HloInstruction * tuple_copy,
                          computation->DeepCopyInstruction(tuple));

  // Assert the structure: the tuple_index checks below walk two operand levels
  // deep and are only meaningful if this match holds.
  ASSERT_THAT(tuple_copy, GmockMatch(m::Tuple(
                              m::Copy(m::GetTupleElement(m::Op().Is(tuple))),
                              m::Copy(m::GetTupleElement(m::Op().Is(tuple))))));
  EXPECT_EQ(0, tuple_copy->operand(0)->operand(0)->tuple_index());
  EXPECT_EQ(1, tuple_copy->operand(1)->operand(0)->tuple_index());
}
275
TEST_F(HloComputationTest, DeepCopyArrayAtIndices) {
  // DeepCopyInstruction on an array when the caller specifies which indices
  // to copy.
  auto builder = HloComputation::Builder(TestName());
  auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
  auto computation = builder.Build();

  {
    // If the index is true, then a copy should be made.
    ShapeTree<bool> indices_to_copy(constant->shape(), /*init_value=*/true);
    // Report a failed status as a test failure instead of aborting.
    TF_ASSERT_OK_AND_ASSIGN(
        HloInstruction * copy,
        computation->DeepCopyInstruction(constant, &indices_to_copy));
    EXPECT_THAT(copy, GmockMatch(m::Copy(m::Op().Is(constant))));
  }

  {
    // If the index is false, then no copy should be made and the original
    // instruction is returned.
    ShapeTree<bool> indices_to_copy(constant->shape(), /*init_value=*/false);
    TF_ASSERT_OK_AND_ASSIGN(
        HloInstruction * not_copied,
        computation->DeepCopyInstruction(constant, &indices_to_copy));
    EXPECT_EQ(not_copied, constant);
  }
}
300
TEST_F(HloComputationTest, DeepCopyTupleAtIndices) {
  // DeepCopyInstruction on a tuple must copy exactly the elements selected by
  // the given indices and record the added copies in `copies_added`.
  auto builder = HloComputation::Builder(TestName());
  auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));
  auto computation = builder.Build();

  {
    // All true values should copy all array elements.
    ShapeTree<bool> indices_to_copy(tuple->shape(), /*init_value=*/true);
    ShapeTree<HloInstruction*> copies_added(tuple->shape(),
                                            /*init_value=*/nullptr);
    // Report a failed status as a test failure instead of aborting.
    TF_ASSERT_OK_AND_ASSIGN(
        HloInstruction * deep_copy,
        computation->DeepCopyInstruction(tuple, &indices_to_copy,
                                         &copies_added));

    EXPECT_THAT(deep_copy, GmockMatch(m::Tuple(
                               m::Copy(m::GetTupleElement(m::Op().Is(tuple)))
                                   .Is(copies_added.element({0})),
                               m::Copy(m::GetTupleElement(m::Op().Is(tuple)))
                                   .Is(copies_added.element({1})))));
  }

  {
    // All false elements should copy no array elements, but the GTE and tuple
    // instruction scaffolding should be built.
    ShapeTree<bool> indices_to_copy(tuple->shape(), /*init_value=*/false);
    ShapeTree<HloInstruction*> copies_added(tuple->shape(),
                                            /*init_value=*/nullptr);
    TF_ASSERT_OK_AND_ASSIGN(
        HloInstruction * deep_copy,
        computation->DeepCopyInstruction(tuple, &indices_to_copy,
                                         &copies_added));

    EXPECT_THAT(deep_copy,
                GmockMatch(m::Tuple(m::GetTupleElement(m::Op().Is(tuple)),
                                    m::GetTupleElement(m::Op().Is(tuple)))));
    EXPECT_TRUE(copies_added.element({}) == nullptr);
    EXPECT_TRUE(copies_added.element({0}) == nullptr);
    EXPECT_TRUE(copies_added.element({1}) == nullptr);
  }

  {
    // Verify one element copied, the other not.
    ShapeTree<bool> indices_to_copy(tuple->shape(), /*init_value=*/false);
    *indices_to_copy.mutable_element({0}) = true;
    ShapeTree<HloInstruction*> copies_added(tuple->shape(),
                                            /*init_value=*/nullptr);
    TF_ASSERT_OK_AND_ASSIGN(
        HloInstruction * deep_copy,
        computation->DeepCopyInstruction(tuple, &indices_to_copy,
                                         &copies_added));

    EXPECT_THAT(deep_copy, GmockMatch(m::Tuple(
                               m::Copy(m::GetTupleElement(m::Op().Is(tuple))),
                               m::GetTupleElement(m::Op().Is(tuple)))));
    EXPECT_TRUE(copies_added.element({}) == nullptr);
    EXPECT_TRUE(copies_added.element({0}) != nullptr);
    EXPECT_TRUE(copies_added.element({1}) == nullptr);
  }
}
365
TEST_F(HloComputationTest, DeepCopyToken) {
  // DeepCopyInstruction on a token must not insert a copy, since tokens should
  // not be copied.
  auto builder = HloComputation::Builder(TestName());
  auto token = builder.AddInstruction(HloInstruction::CreateToken());
  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  // Report a failed status as a test failure instead of aborting the test
  // binary.
  TF_ASSERT_OK_AND_ASSIGN(HloInstruction * copy,
                          computation->DeepCopyInstruction(token));

  // No copy should be added: the result still matches an after-all.
  EXPECT_THAT(copy, GmockMatch(m::AfterAll()));
}
378
TEST_F(HloComputationTest, DeepCopyTokenTuple) {
  // DeepCopyInstruction on a (token, array) tuple must copy the array element
  // and leave the token element uncopied.
  auto builder = HloComputation::Builder(TestName());
  auto token = builder.AddInstruction(HloInstruction::CreateToken());
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({token, constant}));
  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  // Report a failed status as a test failure instead of aborting the test
  // binary.
  TF_ASSERT_OK_AND_ASSIGN(HloInstruction * copy,
                          computation->DeepCopyInstruction(tuple));

  // Only the array (second tuple element) should be copied. The token is
  // passed through transparently.
  EXPECT_THAT(copy, GmockMatch(m::Tuple(
                        m::GetTupleElement(m::Op().Is(tuple)),
                        m::Copy(m::GetTupleElement(m::Op().Is(tuple))))));
}
398
TEST_F(HloComputationTest, CycleDetection) {
  // Accept must fail with a cycle error when a control dependency makes the
  // instruction graph cyclic.
  auto builder = HloComputation::Builder(TestName());
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, negate, negate));
  // An unverified module is used because the graph is made cyclic on purpose.
  auto module = CreateNewUnverifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  // Add a control dependency to create a cycle: add already uses negate as an
  // operand, and now negate must also run after add.
  ASSERT_IS_OK(add->AddControlDependencyTo(negate));

  auto instructions = computation->MakeInstructionPostOrder();
  EXPECT_EQ(3, instructions.size());

  FunctionVisitor visitor(
      [](HloInstruction* instruction) { return OkStatus(); });
  auto visit_status = computation->Accept(&visitor);
  ASSERT_FALSE(visit_status.ok());
  // The pattern previously ended in the truncated word "detecte"; match the
  // complete phrase.
  ASSERT_THAT(visit_status.error_message(),
              ::testing::ContainsRegex("cycle is detected"));
}
423
TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) {
  // RemoveInstructionAndUnusedOperands is applied to a dead add that uses the
  // same dead negate for both operands. The test verifies the shared operand
  // is not deleted twice and that the live root is left untouched.
  HloComputation::Builder builder(TestName());
  HloInstruction* constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  HloInstruction* dead_negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
  HloInstruction* dead_add = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, dead_negate,
                                   dead_negate));
  HloInstruction* live_negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  // The root must be the live negate of the constant, before and after the
  // removal.
  auto expect_root_is_live_negate = [&]() {
    EXPECT_THAT(computation->root_instruction(),
                GmockMatch(m::Negate(m::Op().Is(constant))));
    EXPECT_EQ(live_negate, computation->root_instruction());
  };

  EXPECT_EQ(4, computation->instruction_count());
  expect_root_is_live_negate();

  ASSERT_IS_OK(computation->RemoveInstructionAndUnusedOperands(dead_add));

  // Only the constant and the live negate remain.
  EXPECT_EQ(2, computation->instruction_count());
  expect_root_is_live_negate();
}
451
TEST_F(HloComputationTest, CloneWithControlDependency) {
  // Cloning a computation must carry over control dependencies: the cloned
  // root add must have a cloned negate as its only control predecessor.
  auto builder = HloComputation::Builder(TestName());
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0f)));
  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, constant1, constant2));

  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "param0"));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param));
  auto module = CreateNewVerifiedModule();
  // The add is made the root explicitly; the negate is otherwise unused.
  auto computation =
      module->AddEntryComputation(builder.Build(/*root_instruction=*/add));

  // Fail this test (rather than abort the process) if the dependency cannot be
  // added, consistent with the other tests in this file.
  ASSERT_IS_OK(negate->AddControlDependencyTo(add));

  auto clone = computation->Clone();

  auto cloned_add = clone->root_instruction();
  EXPECT_EQ(cloned_add->opcode(), HloOpcode::kAdd);

  const auto& predecessors = cloned_add->control_predecessors();
  // Assert the size: predecessors[0] is indexed below.
  ASSERT_EQ(1, predecessors.size());
  EXPECT_EQ(HloOpcode::kNegate, predecessors[0]->opcode());
  const auto& successors = predecessors[0]->control_successors();
  EXPECT_THAT(successors, ::testing::ElementsAre(cloned_add));
}
482
TEST_F(HloComputationTest, CloneWithReplacements) {
  // Clones a three-parameter computation while replacing parameter 2 with a
  // parameter of a different element type and appending one extra parameter,
  // then checks the shapes of the clone's four parameters.
  HloComputation::Builder builder(TestName());
  const Shape r0s64 = ShapeUtil::MakeShape(S64, {});
  const Shape r0s32 = ShapeUtil::MakeShape(S32, {});
  const Shape r0u32 = ShapeUtil::MakeShape(U32, {});

  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "p.0.lhs"));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32_, "p.0.rhs"));
  HloInstruction* replaced_param =
      builder.AddInstruction(HloInstruction::CreateParameter(2, r0s64, "p.1"));
  HloInstruction* less_than = builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), lhs, rhs,
                                    ComparisonDirection::kLt));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(
      builder.Build(/*root_instruction=*/less_than));

  // Swap the S64 parameter 2 for an S32 parameter with the same number/name.
  absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
      replacements;
  replacements.emplace(replaced_param,
                       HloInstruction::CreateParameter(2, r0s32, "p.1"));

  // One additional U32 parameter appended as parameter 3.
  std::unique_ptr<HloInstruction> extra_param =
      HloInstruction::CreateParameter(3, r0u32, "p.2");
  std::vector<const HloInstruction*> extra_parameters{extra_param.get()};

  auto clone =
      computation->CloneWithReplacements(&replacements, extra_parameters);
  ASSERT_EQ(clone->num_parameters(), 4);

  const Shape* const expected_shapes[] = {&r0f32_, &r0f32_, &r0s32, &r0u32};
  for (int i = 0; i < 4; ++i) {
    EXPECT_TRUE(ShapeUtil::Equal(clone->parameter_instruction(i)->shape(),
                                 *expected_shapes[i]));
  }
}
518
// Checks HloComputation::ToString, with metadata printing disabled, for a
// computation made of two parameters, a transpose and a dot, after the
// execution thread has been set to "MainThread".
// NOTE(review): the expected text is compared with EXPECT_EQ, so every space
// and newline inside the raw string literal is significant; confirm its
// indentation against the real ToString output before editing it.
TEST_F(HloComputationTest,Stringification)519 TEST_F(HloComputationTest, Stringification) {
520 const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
521 const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
522 const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
523 const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});
524
525 HloComputation::Builder builder("TransposeDot");
526 HloInstruction* x =
527 builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
528 HloInstruction* y =
529 builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
// Despite the local name, this instruction is a transpose of y.
530 HloInstruction* reshape =
531 builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
// Contract dimension 1 of the lhs with dimension 0 of the rhs.
532 DotDimensionNumbers dot_dnums;
533 dot_dnums.add_lhs_contracting_dimensions(1);
534 dot_dnums.add_rhs_contracting_dimensions(0);
// Two operand precision entries, both DEFAULT.
535 PrecisionConfig precision_config;
536 precision_config.mutable_operand_precision()->Resize(
537 2, PrecisionConfig::DEFAULT);
538 builder.AddInstruction(
539 HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config));
540 auto module = CreateNewVerifiedModule();
541 auto* computation = module->AddEntryComputation(builder.Build());
542 computation->SetExecutionThread("MainThread");
543
544 auto options = HloPrintOptions().set_print_metadata(false);
545 const std::string expected_computation =
546 R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
547 %x = f32[5,10]{1,0} parameter(0)
548 %y = f32[20,10]{1,0} parameter(1)
549 %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0}
550 ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
551 }, execution_thread="MainThread")";
552 EXPECT_EQ(computation->ToString(options), expected_computation);
553 }
554
// Checks HloComputation::ToString for the same parameter/transpose/dot
// computation when the print options also request set_indent_amount(2).
// NOTE(review): the expected text is compared with EXPECT_EQ, so every space
// and newline inside the raw string literal is significant; confirm its
// indentation against the real ToString output before editing it.
TEST_F(HloComputationTest,StringificationIndent)555 TEST_F(HloComputationTest, StringificationIndent) {
556 const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
557 const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
558 const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
559 const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});
560
561 HloComputation::Builder builder("TransposeDot");
562 HloInstruction* x =
563 builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
564 HloInstruction* y =
565 builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
// Despite the local name, this instruction is a transpose of y.
566 HloInstruction* reshape =
567 builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
// Contract dimension 1 of the lhs with dimension 0 of the rhs.
568 DotDimensionNumbers dot_dnums;
569 dot_dnums.add_lhs_contracting_dimensions(1);
570 dot_dnums.add_rhs_contracting_dimensions(0);
// Two operand precision entries, both DEFAULT.
571 PrecisionConfig precision_config;
572 precision_config.mutable_operand_precision()->Resize(
573 2, PrecisionConfig::DEFAULT);
574 builder.AddInstruction(
575 HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config));
576 auto module = CreateNewVerifiedModule();
577 auto* computation = module->AddEntryComputation(builder.Build());
578 computation->SetExecutionThread("MainThread");
579
// Same options as the plain stringification test, plus an indent amount of 2.
580 auto options =
581 HloPrintOptions().set_print_metadata(false).set_indent_amount(2);
582 const std::string expected_computation =
583 R"( %TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
584 %x = f32[5,10]{1,0} parameter(0)
585 %y = f32[20,10]{1,0} parameter(1)
586 %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0}
587 ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
588 }, execution_thread="MainThread")";
589 EXPECT_EQ(computation->ToString(options), expected_computation);
590 }
591
// Checks HloComputation::ToString for the same parameter/transpose/dot
// computation twice: first with metadata printing disabled, then with
// HloPrintOptions().Canonical(). The canonical expectation below names the
// instructions tmp_0..tmp_3 and has no parameter/result signature.
// NOTE(review): both expected texts are compared with EXPECT_EQ, so every
// space and newline inside the raw string literals is significant; confirm
// their indentation against the real ToString output before editing them.
TEST_F(HloComputationTest,StringificationCanonical)592 TEST_F(HloComputationTest, StringificationCanonical) {
593 const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
594 const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
595 const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
596 const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});
597
598 HloComputation::Builder builder("TransposeDot");
599 HloInstruction* x =
600 builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
601 HloInstruction* y =
602 builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
// Despite the local name, this instruction is a transpose of y.
603 HloInstruction* reshape =
604 builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
// Contract dimension 1 of the lhs with dimension 0 of the rhs.
605 DotDimensionNumbers dot_dnums;
606 dot_dnums.add_lhs_contracting_dimensions(1);
607 dot_dnums.add_rhs_contracting_dimensions(0);
// Two operand precision entries, both DEFAULT.
608 PrecisionConfig precision_config;
609 precision_config.mutable_operand_precision()->Resize(
610 2, PrecisionConfig::DEFAULT);
611 builder.AddInstruction(
612 HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config));
613 auto module = CreateNewVerifiedModule();
614 auto* computation = module->AddEntryComputation(builder.Build());
615 computation->SetExecutionThread("MainThread");
616
// First expectation: default-style printing with metadata disabled.
617 auto options = HloPrintOptions().set_print_metadata(false);
618 const std::string expected_computation1 =
619 R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
620 %x = f32[5,10]{1,0} parameter(0)
621 %y = f32[20,10]{1,0} parameter(1)
622 %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0}
623 ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
624 }, execution_thread="MainThread")";
625 EXPECT_EQ(computation->ToString(options), expected_computation1);
626
// Second expectation: canonical printing of the same computation.
627 options = HloPrintOptions().Canonical();
628 const std::string expected_computation2 = R"(TransposeDot {
629 tmp_0 = f32[5,10]{1,0} parameter(0)
630 tmp_1 = f32[20,10]{1,0} parameter(1)
631 tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
632 ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
633 }, execution_thread="MainThread")";
634 EXPECT_EQ(computation->ToString(options), expected_computation2);
635 }
636
MakeAddNComputation(int n)637 std::unique_ptr<HloComputation> MakeAddNComputation(int n) {
638 auto builder = HloComputation::Builder("add_n");
639 auto result = builder.AddInstruction(HloInstruction::CreateParameter(
640 0, ShapeUtil::MakeShape(F32, {}), "x_value"));
641 auto one = builder.AddInstruction(
642 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
643 for (int i = 0; i < n; ++i) {
644 result = builder.AddInstruction(HloInstruction::CreateBinary(
645 one->shape(), HloOpcode::kAdd, result, one));
646 }
647 return builder.Build();
648 }
649
TEST_F(HloComputationTest, DeepEquality) {
  // Compares very deep add chains with operator==: two chains of the same
  // length must be equal, and a chain one add shorter must differ from both.
  constexpr int kChainLength = 200000;

  std::unique_ptr<HloComputation> first = MakeAddNComputation(kChainLength);
  std::unique_ptr<HloComputation> second = MakeAddNComputation(kChainLength);
  EXPECT_TRUE(*first == *second);

  std::unique_ptr<HloComputation> shorter =
      MakeAddNComputation(kChainLength - 1);
  EXPECT_FALSE(*first == *shorter);
  EXPECT_FALSE(*shorter == *second);
}
659
660 // Tests that cross-module AllReduce instructions are ordered before all their
661 // predecessors and after all their successors.
TEST_F(HloComputationTest,InstructionPostOrderWithAllReduce)662 TEST_F(HloComputationTest, InstructionPostOrderWithAllReduce) {
663 const char* const hlo_string = R"(
664 HloModule Module
665
666 add {
667 lhs = f32[] parameter(0)
668 rhs = f32[] parameter(1)
669 ROOT add = f32[] add(lhs, rhs)
670 }
671
672 ENTRY entry {
673 param = f32[128] parameter(0), sharding={maximal device=0}
674 crs0 = f32[128] all-reduce(param),
675 replica_groups={{0}}, channel_id=1, to_apply=add,
676 sharding={maximal device=0}
677 crs1 = f32[128] all-reduce(param),
678 replica_groups={{0}}, channel_id=1, to_apply=add,
679 sharding={maximal device=1}
680 add = f32[128] add(crs0, crs0), sharding={maximal device=0}
681 ROOT t = (f32[128], f32[128]) tuple(add, crs1)
682 })";
683 TF_ASSERT_OK_AND_ASSIGN(auto module,
684 ParseAndReturnVerifiedModule(hlo_string));
685 EXPECT_THAT(module->entry_computation()->MakeInstructionPostOrder(),
686 ElementsAre(op::Parameter(), op::AllReduce(), op::AllReduce(),
687 op::Add(), op::Tuple()));
688 }
689
690 } // namespace
691 } // namespace xla
692