/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/conditional_code_motion.h"

#include <deque>
#include <iterator>
#include <stack>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"

namespace xla {

namespace conditional_opt {

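// Visits boundaries from a worklist, skipping already-visited ones.
// Typical usage (a sketch; this mirrors ConditionalCodeMotion::Run below):
//   BoundaryVisitor visitor(conditional);
//   while (visitor.HasNextBoundary()) {
//     Boundary b = visitor.PopNextBoundary();
//     ...  // analyze b; queue follow-up boundaries via AddToWorkList.
//   }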
class BoundaryVisitor {
 public:
  // Start with an existing conditional computation.
  explicit BoundaryVisitor(HloInstruction* conditional) {
    Boundary b(Boundary::Position::kInsideBranch);
    b.mutable_operands().push_back(conditional);
    worklist_.push_back(b);
  }
  // Start with an empty work list.
  BoundaryVisitor() {}
  // Get next boundary to visit.
  Boundary PopNextBoundary() {
    CHECK(!worklist_.empty());
    Boundary b = worklist_.front();
    worklist_.pop_front();
    // If b has already been visited, it must have multiple users and is
    // already in new boundaries; skip it. Only the first operand of b is
    // checked, because b is expected to have at least one operand and all
    // operands in b must be identical instructions from different branches
    // for b to be moved.
    while (!worklist_.empty() && ContainsKey(visited_, b.operands()[0])) {
      b = worklist_.front();
      worklist_.pop_front();
    }
    visited_.insert(b.operands()[0]);
    return b;
  }
  void AddToWorkList(const Boundary& b) {
    CHECK(!b.operands().empty());
    worklist_.push_back(b);
  }

  bool HasNextBoundary() {
    while (!worklist_.empty()) {
      Boundary b = worklist_.front();
      if (!ContainsKey(visited_, b.operands()[0])) {
        break;
      }
      worklist_.pop_front();
    }
    return !worklist_.empty();
  }

 private:
  // worklist_ is the deque of boundaries waiting to be visited.
  std::deque<Boundary> worklist_;
  absl::flat_hash_set<HloInstruction*> visited_;
};

template <class OpCollection>
int64 CountNonLeafOps(const OpCollection& ops) {
  absl::flat_hash_set<HloInstruction*> op_set;
  for (auto op : ops) {
    if (!op_set.contains(op) && op->opcode() != HloOpcode::kConstant) {
      op_set.insert(op);
    }
  }
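  // E.g., for ops {c (kConstant), x, x, y} the set becomes {x, y}:
  // constants are skipped and duplicates are counted once.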
  return op_set.size();
}

// Returns an estimation of potential reuses carried by a given pair of
// instructions. Uses different integers to classify different levels of
// reuse. This is used as a placeholder only, assuming all instructions
// can be fused to enable data reuse.
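// For example, under this placeholder model:
//   ReusesCarriedBy(kDot, kConvert) == 0   (dot and convert stay separated)
//   ReusesCarriedBy(kConditional, kAdd) == 10
//   ReusesCarriedBy(kAdd, kMultiply) == -10  (the default weight)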
int64 ReusesCarriedBy(HloOpcode op, HloOpcode user) {
  // Reuses in some way work like forces that pull instructions
  // towards each other. We use a number 0-10 to classify how strong the force
  // is between a pair of operations. Given a group of instructions that can be
  // moved together, if the forces inside a conditional are stronger, the group
  // will be moved inside or remain inside the conditional; otherwise, it will
  // be moved outside to or remain outside of the conditional.
  switch (user) {
    case HloOpcode::kGetTupleElement:
      return 0;
    case HloOpcode::kConvert:
      // Because a convert is treated as not movable when it follows a dot or
      // convolution, if op is a dot or convolution here, they must be
      // separated by a conditional boundary. In that case we do not try to
      // pull the convert inside the conditional to be together with the dot
      // or convolution.
      switch (op) {
        case HloOpcode::kConvolution:
        case HloOpcode::kDot:
          return 0;
        default:
          break;
      }
      break;
    default:
      break;
  }
  switch (op) {
    // These instructions do not carry weight of reuse themselves.
    case HloOpcode::kParameter:
    case HloOpcode::kConstant:
    case HloOpcode::kGetTupleElement:
      return 0;
    case HloOpcode::kConditional:
      return 10;
    default:
      return -10;
  }
}

// Returns true if `op` is worth hoisting.
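// For example, WorthHoisting(kConvert, kAllReduce) is true (moving the
// all-reduce out may enable AR/CRS combining), while
// WorthHoisting(kConvert, kDot) is false (the convert should stay fusible
// with the dot inside the branch).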
bool WorthHoisting(HloOpcode op, HloOpcode child_op) {
  // TODO(b/169182921): The following cost model is rather incomplete. It will
  // need to be extended to cover most of the element-wise ops.
  switch (op) {
    case HloOpcode::kConvert:
      // If Convert is after AllReduce, it is worth moving the AllReduce out
      // of the conditional for AR/CRS combine. If Convert is after other
      // ops such as Dot or Convolution, it is better to keep the convert
      // within the conditional so that it can be fused with the Dot or
      // Convolution.
      switch (child_op) {
        case HloOpcode::kAllReduce:
        case HloOpcode::kReshape:
        case HloOpcode::kGetTupleElement:
          return true;
        default:
          return false;
      }
    case HloOpcode::kGetTupleElement:
      switch (child_op) {
        // Do not move a GTE if its operand is a parameter.
        case HloOpcode::kParameter:
          return false;
        default:
          return true;
      }
    case HloOpcode::kAllReduce:
    case HloOpcode::kAbs:
    case HloOpcode::kReduce:
    case HloOpcode::kAdd:
    case HloOpcode::kPower:
    case HloOpcode::kCopy:
    case HloOpcode::kConstant:
    case HloOpcode::kSubtract:
    case HloOpcode::kMultiply:
    case HloOpcode::kDivide:
    case HloOpcode::kTuple:
    case HloOpcode::kSqrt:
    case HloOpcode::kRsqrt:
    case HloOpcode::kReshape:
    case HloOpcode::kMinimum:
    case HloOpcode::kMaximum:
      return true;
    default:
      return false;
  }
}

// Checks whether the instructions to be visited in the different branches
// are identical.
bool InstructionWithinBranchIdentical(
    const std::vector<HloInstruction*>& instructions,
    bool is_layout_sensitive) {
  // Identical means the shapes of corresponding operands are also equal.
  auto eq_operand = [&](const HloInstruction* a, const HloInstruction* b) {
    bool eq_operands = is_layout_sensitive
                           ? ShapeUtil::Equal(a->shape(), b->shape())
                           : ShapeUtil::Compatible(a->shape(), b->shape());
    return eq_operands;
  };

  auto eq_computations = [](const HloComputation* a, const HloComputation* b) {
    return *a == *b;
  };

  if (instructions.empty()) {
    return false;
  }

  if (instructions[0]->IsCrossModuleAllReduce()) {
    return std::all_of(
        instructions.begin(), instructions.end(),
        [&](HloInstruction* instruction) {
          if (!instruction->IsCrossModuleAllReduce()) {
            return false;
          }
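          // Compare all-reduces modulo their channel ids: temporarily give
          // this instruction the first instruction's channel id, run
          // Identical(), and then restore the original id.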
          auto old_channel_id = instruction->channel_id();
          instruction->set_channel_id(instructions[0]->channel_id());
          bool eq_instructions = instructions[0]->Identical(
              *instruction, eq_operand, eq_computations, is_layout_sensitive);
          instruction->set_channel_id(old_channel_id);
          return eq_instructions;
        });
  }

  return std::all_of(instructions.begin(), instructions.end(),
                     [&](HloInstruction* instruction) {
                       return instructions[0]->Identical(
                           *instruction, eq_operand, eq_computations,
                           is_layout_sensitive);
                     });
}

// Copy the ith instruction in boundary to outside of conditional, or do the
// opposite (for moving in).
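// On return, hoisted_instructions maps each original instruction to a
// Boundary whose operands()[dest_index] is the copy created at the
// destination (the branch index when moving in, 0 when moving out).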
Status CopyInOrOutOfConditional(
    Boundary& boundary, int64 dest_index, HloComputation* parent,
    absl::flat_hash_map<HloInstruction*, Boundary>& hoisted_instructions) {
  CHECK(dest_index == 0 || boundary.IsOutsideBranch());
  HloInstruction* op = boundary.operands()[0];
  absl::InlinedVector<HloInstruction*, 4> new_operands;
  for (int i = 0; i < op->operands().size(); ++i) {
    auto op_i = op->operands()[i];
    VLOG(2) << "Looking for " << op_i->ToString() << "\n";
    if (ContainsKey(hoisted_instructions, op_i)) {
      auto new_op_i =
          FindOrDie(hoisted_instructions, op_i).operands()[dest_index];
      VLOG(2) << "new instruction:" << new_op_i->ToString() << "\n";
      new_operands.push_back(new_op_i);
    } else {
      switch (op_i->opcode()) {
        case HloOpcode::kConstant: {
          auto new_op_i = parent->AddInstruction(op_i->Clone());
          VLOG(2) << "new instruction:" << new_op_i->ToString() << "\n";
          new_operands.push_back(new_op_i);
          break;
        }
        case HloOpcode::kGetTupleElement: {
          auto gte = Cast<HloGetTupleElementInstruction>(op_i);
          int64 index = gte->tuple_index();
          HloInstruction* root = parent->root_instruction();
          CHECK(root->opcode() == HloOpcode::kTuple &&
                index < root->operand_count());
          auto new_op_i = root->mutable_operand(index);
          VLOG(2) << "new instruction:" << new_op_i->ToString() << "\n";
          new_operands.push_back(new_op_i);
          break;
        }
        default:
          LOG(FATAL) << "Unexpected out-of-boundary instruction:"
                     << op_i->ToString() << "\n";
      }
    }
  }
  HloInstruction* new_instruction = parent->AddInstruction(
      op->CloneWithNewOperands(op->shape(), new_operands));
  VLOG(2) << "new instruction:" << new_instruction->ToString() << "\n";
  // Map each original instruction to its newly created copy on the other
  // side of the conditional boundary.
  for (HloInstruction* op : boundary.operands()) {
    Boundary b2 = ContainsKey(hoisted_instructions, op)
                      ? hoisted_instructions[op]
                      : Boundary(boundary.IsOutsideBranch()
                                     ? Boundary::Position::kInsideBranch
                                     : Boundary::Position::kOutsideBranch);
    b2.mutable_operands().push_back(new_instruction);
    hoisted_instructions[op] = b2;
  }
  return Status::OK();
}

// Identify converts to be hoisted/rematerialized out of the branch
// computations.
absl::flat_hash_set<int64> FindSpecialConverts(HloInstruction* old_root,
                                               int branch_count,
                                               HloInstruction* conditional,
                                               bool is_layout_sensitive) {
  absl::flat_hash_set<int64> kspecial_convert;
  for (int64 operand_num = 0; operand_num < old_root->operand_count();
       ++operand_num) {
    if (old_root->operand(operand_num)->opcode() != HloOpcode::kConvert) {
      continue;
    }
    bool replica = true;
    HloInstruction* kspecial_convert_candidate =
        old_root->mutable_operand(operand_num);
    // Check whether an identical candidate appears in the other branches.
    for (int others = 1; others < branch_count; ++others) {
      HloInstruction* others_root =
          conditional->branch_computation(others)->root_instruction();
      bool eq_shape =
          is_layout_sensitive
              ? ShapeUtil::Equal(others_root->operand(operand_num)->shape(),
                                 kspecial_convert_candidate->shape())
              : ShapeUtil::Compatible(
                    others_root->operand(operand_num)->shape(),
                    kspecial_convert_candidate->shape());
      if ((others_root->operand(operand_num)->opcode() ==
           HloOpcode::kConvert) &&
          eq_shape) {
        // Nothing to be done.
      } else {
        replica = false;
        break;
      }
    }
    if (replica) {
      kspecial_convert.insert(operand_num);
    }
  }
  return kspecial_convert;
}

// Restructuring the conditional instruction as follows:
// i.e., %result = conditional() becomes
// x = conditional()
// y.{0..n} = gte(x, {0..n})
// z = tuple(y.0, y.1, ...y.n)
// Doing so ensures that we can accommodate the possible shape-change of the
// conditional when the instructions are hoisted.
Status RestructureConditionalInstruction(HloComputation* computation,
                                         HloInstruction* conditional) {
  HloInstruction* old_root = computation->root_instruction();
  std::vector<HloInstruction*> new_operands;
  int cur_index = 0;
  for (; cur_index < ShapeUtil::TupleElementCount(conditional->shape());
       ++cur_index) {
    new_operands.push_back(
        computation->AddInstruction(HloInstruction::CreateGetTupleElement(
            ShapeUtil::GetTupleElementShape(conditional->shape(), cur_index),
            conditional, cur_index)));
  }
  HloInstruction* new_tuple =
      computation->AddInstruction(HloInstruction::CreateTuple(new_operands));
  if (old_root == conditional) {
    computation->set_root_instruction(new_tuple);
  } else {
    std::vector<HloInstruction*> new_tuple_users;
    for (auto conditional_user : conditional->users()) {
      auto is_new_gte = absl::c_find_if(
          new_operands,
          [&](HloInstruction* instr) { return instr == conditional_user; });
      if (is_new_gte == new_operands.end()) {
        new_tuple_users.push_back(conditional_user);
      }
    }
    for (auto new_tuple_user : new_tuple_users) {
      TF_RETURN_IF_ERROR(
          conditional->ReplaceUseWith(new_tuple_user, new_tuple));
    }
  }
  VLOG(2) << "computation after root restructure:\n"
          << computation->ToString();
  return Status::OK();
}

StatusOr<bool> ConvertSpecialMove(HloInstruction* conditional,
                                  bool is_layout_sensitive) {
  int branch_count = conditional->branch_count();
  if (branch_count <= 0) {
    return false;
  }

  // Determine whether all branch roots are tuples.
  for (int branch_num = 0; branch_num < branch_count; ++branch_num) {
    HloInstruction* branch_root =
        conditional->branch_computation(branch_num)->root_instruction();
    if (branch_root->opcode() != HloOpcode::kTuple) {
      return false;
    }
  }

  HloInstruction* old_root =
      conditional->branch_computation(0)->root_instruction();
  VLOG(2) << "BEFORE :" << conditional->parent()->parent()->ToString();
  // Identify the gte using `index'.
  auto find_gte = [](const HloInstruction* conditional_result,
                     int64 index) -> HloInstruction* {
    for (HloInstruction* instr : conditional_result->users()) {
      if (instr->opcode() != HloOpcode::kGetTupleElement) {
        return nullptr;
      }
      if (instr->tuple_index() == index) {
        return instr;
      }
    }
    return nullptr;
  };

  // Captures tuple indices referring to converts to be rematerialized or
  // hoisted.
  absl::flat_hash_set<int64> kspecial_convert = FindSpecialConverts(
      old_root, branch_count, conditional, is_layout_sensitive);

  // Exit if we cannot find any converts to be hoisted.
  if (kspecial_convert.empty()) {
    return false;
  }

  TF_RETURN_IF_ERROR(
      RestructureConditionalInstruction(conditional->parent(), conditional));

  for (int branch = 0; branch < branch_count; branch++) {
    old_root = conditional->branch_computation(branch)->root_instruction();
    absl::flat_hash_map<HloInstruction*, int64> map_inst_to_tuple_index;
    std::vector<HloInstruction*> new_operands(old_root->operand_count());
    absl::flat_hash_set<HloInstruction*> to_hoist_set;

    for (int64 operand_num = 0; operand_num < old_root->operand_count();
         ++operand_num) {
      map_inst_to_tuple_index[old_root->mutable_operand(operand_num)] =
          operand_num;
    }
    for (int64 operand_num = 0; operand_num < old_root->operand_count();
         ++operand_num) {
      HloInstruction* hoist = old_root->mutable_operand(operand_num);
      if (!kspecial_convert.contains(operand_num)) {
        new_operands[operand_num] = old_root->mutable_operand(operand_num);
        continue;
      }

      to_hoist_set.insert(hoist);
      int64 new_tuple_count = old_root->operand_count();

      // Replace the hoisted instr in the tuple with its operand(s).
      // We will replace at least one of the operands of the hoist at the
      // tuple place; the rest will be added at the end.
      bool inplace = true;
      CHECK(!hoist->operands().empty());
      for (HloInstruction* prod : hoist->operands()) {
        if (inplace) {
          map_inst_to_tuple_index[prod] = map_inst_to_tuple_index[hoist];
          new_operands[map_inst_to_tuple_index[hoist]] = prod;
          inplace = false;
        } else {
          map_inst_to_tuple_index[prod] = new_tuple_count++;
          new_operands.push_back(prod);
        }
      }
    }

    // Create the new root instruction.
    HloComputation* cur_branch = conditional->branch_computation(branch);
    HloInstruction* new_branch_root =
        cur_branch->AddInstruction(HloInstruction::CreateTuple(new_operands));
    // The shape can vary since the operands to convert are now
    // being returned through the branches' root.
    cur_branch->set_root_instruction(new_branch_root, true /*new shape*/);
    TF_CHECK_OK(cur_branch->RemoveInstruction(old_root));

    // Only one of the branches needs to change the conditional->parent().
    if (branch != 0) {
      continue;
    }
    HloComputation* conditional_parent = conditional->parent();
    HloInstruction* newconditional =
        conditional_parent->AddInstruction(HloInstruction::CreateConditional(
            cur_branch->root_instruction()->shape(),
            conditional->mutable_operand(0),
            absl::MakeSpan(conditional->branch_computations()),
            absl::MakeSpan(conditional->operands()).subspan(1)));
    // Ensure that all the users of conditional refer to the new one.
    TF_RETURN_IF_ERROR(
        conditional->ReplaceAllUsesWithDifferentShape(newconditional));
    TF_CHECK_OK(conditional_parent->RemoveInstruction(conditional));
    conditional = newconditional;
    // Add the hoisted instructions to the parent.
    for (HloInstruction* hoist : to_hoist_set) {
      VLOG(2) << "Hoisting instruction:" << hoist->ToString();
      int64 hoist_index = map_inst_to_tuple_index[hoist];
      // Find the gte that captured the hoisted instr result.
      HloInstruction* gte_hoist = find_gte(conditional, hoist_index);
      CHECK(gte_hoist != nullptr);
      std::vector<HloInstruction*> new_operands;
      for (HloInstruction* op : hoist->operands()) {
        HloInstruction* gte = conditional_parent->AddInstruction(
            HloInstruction::CreateGetTupleElement(
                op->shape(), conditional, map_inst_to_tuple_index[op]));
        new_operands.push_back(gte);
      }
      HloInstruction* hoisted = conditional_parent->AddInstruction(
          hoist->CloneWithNewOperands(hoist->shape(), new_operands));
      VLOG(2) << "Hoisted instruction in parent:" << hoisted->ToString();
      TF_RETURN_IF_ERROR(gte_hoist->ReplaceAllUsesWith(hoisted));
      TF_CHECK_OK(conditional_parent->RemoveInstruction(gte_hoist));
    }
    // There is no need to explicitly delete a hoisted instruction: if it is
    // dead, the subsequent DCE will remove it.
  }
  VLOG(2) << "AFTER :" << conditional->parent()->parent()->ToString();
  return true;
}

// Hoist identical ops out of the conditional. Ops are considered identical
// if the shapes of their operands and their other properties all match.
// Start from the root instruction of each branch and gather the identical
// ops to hoist.
StatusOr<bool> ConditionalCodeMotion::MoveInstructionOut(
    HloInstruction* conditional, std::vector<Boundary>& to_move_out,
    std::vector<Boundary>& new_boundaries) {
  if (to_move_out.empty()) {
    return false;
  }
  VLOG(1) << "Modifying code--number of boundaries to move out:"
          << to_move_out.size() << "\n";
  HloComputation* conditional_parent = conditional->parent();
  // Save the old users before adding new conditional user instructions.
  std::vector<HloInstruction*> old_conditional_users = conditional->users();
  // Maps instructions in the conditional body to instructions hoisted outside
  // the conditional that compute the same value.
  absl::flat_hash_map<HloInstruction*, Boundary> hoisted_instructions;
  // Insert GetTupleElement before the instructions whose operands might still
  // be within the conditional.
  VLOG(1) << "before opt:"
          << conditional_parent->ToString(HloPrintOptions::Fingerprint())
          << "\n";
  int64 op_index = 0;
  for (const Boundary& b : new_boundaries) {
    HloInstruction* op = b.operands()[0];
    CHECK(op != nullptr);
    VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n";
    HloInstruction* gtr = conditional_parent->AddInstruction(
        HloInstruction::CreateGetTupleElement(op->shape(), conditional,
                                              op_index++));
    Boundary b2(Boundary::Position::kOutsideBranch);
    b2.mutable_operands().push_back(gtr);
    hoisted_instructions[op] = b2;
  }
  // Copy boundary instructions out of the conditional.
  // Visit each operand before its users and copy it, so that the copied
  // users will point to the correct operands.
  for (int64 i = to_move_out.size() - 1; i >= 0; i--) {
    TF_RETURN_IF_ERROR(CopyInOrOutOfConditional(
        to_move_out[i], 0, conditional_parent, hoisted_instructions));
  }
  VLOG(2) << "Done copy branch instructions out\n"
          << conditional_parent->ToString(HloPrintOptions::Fingerprint())
          << "\n";
  // Change original users of the conditional to use the correct operands.
  HloInstruction* old_root =
      conditional->branch_computation(0)->root_instruction();
  for (auto user_instr : old_conditional_users) {
    VLOG(2) << "Checking conditional user: " << user_instr->ToString() << "\n";
    CHECK(user_instr->opcode() == HloOpcode::kGetTupleElement);
    auto tuple_opd = static_cast<HloGetTupleElementInstruction*>(user_instr);
    int64 index = tuple_opd->tuple_index();
    CHECK(old_root->operands().size() > index);
    HloInstruction* old_opd = old_root->operands()[index];
    VLOG(2) << "old opd = " << old_opd << "\n";
    CHECK(ContainsKey(hoisted_instructions, old_opd));
    HloInstruction* new_opd = hoisted_instructions[old_opd].operands()[0];
    CHECK(old_opd != nullptr);
    CHECK(new_opd != nullptr);
    VLOG(2) << "Try replace all uses of :" << old_opd->ToString() << "\n";
    TF_RETURN_IF_ERROR(user_instr->ReplaceAllUsesWith(new_opd));
    TF_RETURN_IF_ERROR(conditional_parent->RemoveInstruction(user_instr));
  }
  VLOG(2) << "Done changing conditional users\n"
          << conditional_parent->ToString() << "\n";
  // Create a tuple of elements within each branch and set it as root.
  int64 branch_count = conditional->branch_count();
  for (int i = 0; i < branch_count; i++) {
    auto computation = conditional->branch_computation(i);
    std::vector<HloInstruction*> elements;
    for (const auto& b1 : new_boundaries) {
      HloInstruction* op = b1.operands()[i];
      CHECK(op != nullptr);
      VLOG(2) << "Adding to root " << i << " with " << op->ToString() << "\n";
      elements.push_back(op);
    }
    HloInstruction* tuple =
        computation->AddInstruction(HloInstruction::CreateTuple(elements));
    computation->set_root_instruction(tuple, true);
    VLOG(2) << "computation is :" << computation->ToString() << "\n";
    // Remove hoisted instructions from the branches.
    for (const auto& b2 : to_move_out) {
      auto instr_to_remove = b2.operands()[i];
      // Double-check to make sure it is safe to delete the instruction.
      // Complications may arise due to some operations in the alternative
      // branches (branches 1..n) being placed into the boundaries multiple
      // times.
      if (!computation->IsMarkedAsDead(instr_to_remove) &&
          instr_to_remove->user_count() == 0) {
        VLOG(2) << "Removing boundary:" << b2.ToString() << "\n";
        TF_RETURN_IF_ERROR(computation->RemoveInstruction(instr_to_remove));
      }
    }
  }
  // Change the conditional instruction shape to the shape of the new root.
  HloInstruction* new_root =
      conditional->branch_computation(0)->root_instruction();
  *conditional->mutable_shape() = new_root->shape();
  VLOG(1) << "done moving instructions out of branches\n"
          << conditional_parent->ToString(HloPrintOptions::Fingerprint())
          << "\n";
  return true;
}

// Hoist ops from outside of the conditional to inside the branches.
StatusOr<bool> ConditionalCodeMotion::MoveInstructionIn(
    HloInstruction* conditional, std::vector<Boundary>& to_move_in,
    std::vector<Boundary>& new_boundaries) {
  if (to_move_in.empty()) {
    return false;
  }
  VLOG(1) << "Modifying code---number of boundaries to move in:"
          << to_move_in.size() << "\n";
  VLOG(1) << "before opt:"
          << conditional->parent()->ToString(HloPrintOptions::Fingerprint())
          << "\n";
  // Mapping of instructions to be moved to their new representations.
  absl::flat_hash_map<HloInstruction*, Boundary> hoisted_instructions;
  int64 to_move_in_size = to_move_in.size();
  int64 branch_count = conditional->branch_count();
  HloGetTupleElementInstruction* tuple_use =
      DynCast<HloGetTupleElementInstruction>(to_move_in[0].operands()[0]);
  // If use_index is -1, the old conditional root entry used by the to_move_in
  // instructions still needs to be included as an entry of the modified
  // conditional root, and the new result of the to_move_in instructions
  // needs to be added as an extra entry of the modified root; otherwise, the
  // old root entry is replaced with the new result in the modified root.
  // The entry replacement should be allowed only if tuple_use has <= 1 users.
  int64 use_index = (tuple_use != nullptr && tuple_use->user_count() == 1)
                        ? tuple_use->tuple_index()
                        : -1;
  VLOG(2) << "Tuple use index = " << use_index << "\n";
  // Number of old conditional entries still to be used outside.
  // If the conditional shape is not a tuple, create a tuple and use subscript
  // 0 to save the old operand being used.
  int64 op_index =
      conditional->shape().IsTuple()
          ? ((use_index >= 0) ? conditional->shape().tuple_shapes_size() - 1
                              : conditional->shape().tuple_shapes_size())
          : 0;
  // Used to map the tuple_use instruction to its operand.
  Boundary b_opd_use(Boundary::Position::kInsideBranch);
  Boundary b_old_root(Boundary::Position::kInsideBranch);
  // Create a new root instruction in each branch.
  for (int i = 0; i < branch_count; i++) {
    auto computation = conditional->branch_computation(i);
    auto old_root = computation->root_instruction();
    b_old_root.mutable_operands().push_back(old_root);
    std::vector<HloInstruction*> operands;
    if (old_root->opcode() == HloOpcode::kTuple) {
      // Use operands of old_root directly, so old_root can be removed later.
      for (int i = 0; i < old_root->operand_count(); ++i) {
        if (i != use_index) {
          operands.push_back(old_root->operands()[i]);
        } else {  // Map conditional use to the tuple operand.
          b_opd_use.mutable_operands().push_back(old_root->operands()[i]);
        }
      }
    } else if (old_root->shape().IsTuple()) {
      // If old_root is not a kTuple but has tuple shape, elements within the
      // tuple must be extracted first to be used by the new instructions.
      const Shape& old_shape = old_root->shape();
      for (int64 i = 0; i < old_shape.tuple_shapes_size(); ++i) {
        auto element =
            computation->AddInstruction(HloInstruction::CreateGetTupleElement(
                old_shape.tuple_shapes(i), old_root, i));
        if (i != use_index) {
          operands.push_back(element);
        } else {
          b_opd_use.mutable_operands().push_back(element);
        }
      }
    } else {
      // If old_root is not a tuple and does not have tuple shape, use it
      // to replace the conditional directly in the new computation.
      b_opd_use.mutable_operands().push_back(conditional);
    }

    HloInstruction* new_root =
        computation->AddInstruction(HloInstruction::CreateTuple(operands));
    VLOG(2) << "setting new root: " << new_root->ToString() << "\n";
    computation->set_root_instruction(new_root,
                                      /*accept_different_shape*/ true);
    if (old_root->opcode() == HloOpcode::kTuple) {
      TF_RETURN_IF_ERROR(computation->RemoveInstruction(old_root));
    }
    VLOG(2) << "new branch computation: " << computation->ToString() << "\n";
  }
  // Update the get-tuple-element indices of the conditional's users.
  if (use_index != -1) {
    for (auto* user : conditional->users()) {
      if (user->opcode() == HloOpcode::kGetTupleElement &&
          user->tuple_index() > use_index) {
        user->set_tuple_index(user->tuple_index() - 1);
      }
    }
  }
  hoisted_instructions[conditional] = b_old_root;
  int64 cp_start = 0;
  if (use_index >= 0) {
    VLOG(2) << "Mapping GTE: " << tuple_use->ToString() << "\n";
    hoisted_instructions[tuple_use] = b_opd_use;
  }
  cp_start = (tuple_use != nullptr) ? 1 : 0;
  for (int64 to_move_index = cp_start; to_move_index < to_move_in_size;
       to_move_index++) {
    Boundary b_to_move = to_move_in[to_move_index];
    HloInstruction* op = b_to_move.operands()[0];
    CHECK(op != nullptr);
    bool to_be_used_outside = true;
    VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n";
    if (to_move_index < to_move_in_size - 1 && op->user_count() == 1 &&
        op->users()[0] == to_move_in[to_move_index + 1].operands()[0]) {
      to_be_used_outside = false;
      VLOG(2) << "Instruction is not to be used outside the branch\n";
    }
    Boundary b(Boundary::Position::kInsideBranch);
    for (int i = 0; i < branch_count; i++) {
      auto computation = conditional->branch_computation(i);
      VLOG(2) << "Copying to branch: " << i << "\n";
      TF_RETURN_IF_ERROR(CopyInOrOutOfConditional(b_to_move, i, computation,
                                                  hoisted_instructions));
      VLOG(2) << "Done:" << computation->ToString() << "\n";
      if (to_be_used_outside) {
        auto new_op = hoisted_instructions[op].operands()[i];
        auto new_root = computation->root_instruction();
        new_root->AppendOperand(new_op);
        *new_root->mutable_shape()->add_tuple_shapes() = new_op->shape();
        VLOG(2) << "Extending conditional root " << i << " : "
                << new_root->ToString() << "\n";
      }
      VLOG(2) << "After extending branch root: " << computation->ToString()
              << "\n";
    }
    if (to_be_used_outside) {
      // Modify uses of instructions outside of the conditionals.
      HloInstruction* gtr = conditional->parent()->AddInstruction(
          HloInstruction::CreateGetTupleElement(op->shape(), conditional,
                                                op_index++));
      TF_RETURN_IF_ERROR(op->ReplaceAllUsesWith(gtr));
      if (conditional->parent()->root_instruction() == op) {
        conditional->parent()->set_root_instruction(gtr);
      }
    }
  }
  VLOG(2) << "Done copying instructions inside branch: "
          << conditional->ToString(HloPrintOptions::Fingerprint()) << "\n";
  // Change the conditional instruction shape to the shape of the new root.
  HloInstruction* new_root =
      conditional->branch_computation(0)->root_instruction();
  *conditional->mutable_shape() = new_root->shape();
  VLOG(2) << "Before removing instructions:"
          << conditional->parent()->ToString() << "\n";
  // Remove hoisted instructions from the branches.
  for (int64 i = to_move_in_size - 1; i >= 0; i--) {
    Boundary boundary_to_move_in = to_move_in[i];
    HloInstruction* op = boundary_to_move_in.operands()[0];
    if (op->user_count() == 0) {
      VLOG(2) << "Removing boundary:" << boundary_to_move_in.ToString()
              << "\n";
      TF_RETURN_IF_ERROR(conditional->parent()->RemoveInstruction(op));
      VLOG(2) << "Done removing boundary.\n";
    }
  }

  // Reset the shapes of user gtes to the new shape.
  if (use_index != -1) {
    for (auto* user : conditional->users()) {
      if (user->opcode() == HloOpcode::kGetTupleElement) {
        VLOG(2) << "Resetting shape of user: " << user->ToString() << "\n";
        *user->mutable_shape() =
            conditional->shape().tuple_shapes(user->tuple_index());
      }
    }
  }
  VLOG(1) << "Done moving instructions inside branches\n"
          << conditional->parent()->ToString(HloPrintOptions::Fingerprint())
          << "\n";
  return true;
}

// Groups single chains of operands or uses of boundaries into new boundaries.
class GroupConnectedBoundaries {
 private:
  std::vector<Boundary> connected_boundaries_, new_boundaries_;
  HloInstruction* conditional_;
  HloComputation* conditional_parent_;
  bool is_layout_sensitive_;
  // Instructions that have been visited but are not going to be moved.
  absl::flat_hash_map<HloInstruction*, int>& visited_count_;
  // The following four members are configurations of the cost model, which
  // will be used to determine whether to move an instruction (move_config_)
  // and how strongly preferred it is to keep a pair of ops together
  // (reuse_config_). The search_config_ is used to control how to navigate
  // the search space of the cost model in the context of auto/manual tuning.
  // The flipped_ map is used to record which entries in the configuration
  // have been changed in the search/tuning process.
  std::vector<std::vector<int64>>& move_config_;
  std::vector<std::vector<int64>>& reuse_config_;
  int& search_config_;
  absl::flat_hash_map<const int64*, int64> flipped_;

  // The FlipMutation function implements the search of alternative
  // cost models by deciding whether to flip a given configuration, saved in
  // the loc parameter. The non_zero parameter provides the new value to use
  // to flip a zero. The msg parameter is only used for debugging purposes.
  int64 FlipMutation(int64* loc, const int64 non_zero,
                     const std::string& msg) {
    if (search_config_ == 0 || ContainsKey(flipped_, loc)) {
      VLOG(2) << "Configured not to search or loc is already flipped.";
      return *loc;
    }

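    // Layout of search_config_, as inferred from the bit manipulation below:
    // bits 0-7 count down to the first flip, bits 8-15 hold the number of
    // flips still allowed, and bits 16-23 hold the amount added back to the
    // countdown after each flip.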
    // Bits 8-15 control the maximum number of times to flip a config.
    int flip_count = (search_config_ >> 8) & 255;
    if (flip_count == 0) {
      VLOG(2) << "Maximum flip count has been reached. ";
      return *loc;
    }

    // The low 8 bits control when to start the first flip.
    int c = search_config_ & 255;
    VLOG(2) << "flip start index = " << c << "\n";
    // Only flip the decision if c reaches 0.
    if (c > 0) {
      search_config_--;
      return *loc;
    }

    // Decrement the flip count so we can stop if it reaches 0.
    search_config_ -= 256;
    // Reload bits 16-23 of the configuration, which control how frequently a
    // configuration should be flipped.
    search_config_ += (search_config_ >> 16) & 255;
    VLOG(2) << "Updating Flipping configuration = " << search_config_ << "\n";

    flipped_[loc] = *loc;
    // Flip the value saved in loc between zero and non_zero.
    switch (*loc) {
      case 0:
        *loc = non_zero;
        break;
      default:
        *loc = 0;
        break;
    }
    VLOG(2) << "Flipping decision for: " << msg << ": from " << flipped_[loc]
            << " to " << *loc << "\n";
    return *loc;
  }

 public:
  explicit GroupConnectedBoundaries(
      HloInstruction* conditional, bool is_layout_sensitive,
      absl::flat_hash_map<HloInstruction*, int>& visited_count,
      std::vector<std::vector<int64>>* move_config,
      std::vector<std::vector<int64>>* reuse_config, int* search_config)
      : conditional_(conditional),
        conditional_parent_(conditional->parent()),
        is_layout_sensitive_(is_layout_sensitive),
        visited_count_(visited_count),
        move_config_(*move_config),
        reuse_config_(*reuse_config),
        search_config_(*search_config) {}
  // Returns an estimation of potential reuses carried by a given pair of
  // instructions. Uses different integers to classify different levels
  // of reuse. Assumes all instructions can be fused to enable data reuse.
  int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) {
    std::vector<int64>& curconfig =
        reuse_config_[static_cast<uint32>(op->opcode())];
    // Flip the reuse configuration if tuning the cost model.
    // When flipping, use -10 if flipping to the default reuse model. Other
    // values can be specified if needed to fine-control the decision making.
    int64 config =
        (search_config_ < 0)
            ? FlipMutation(&curconfig[static_cast<uint32>(user->opcode())], -10,
                           HloOpcodeString(op->opcode()) + "->" +
                               HloOpcodeString(user->opcode()))
            : curconfig[static_cast<uint32>(user->opcode())];
    VLOG(2) << "ConditionalCodeMotion: Add reuses carried by instr: "
            << op->ToString() << "=>" << user->ToString() << " : " << config
            << "\n";
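    // E.g., with the default config of -10, an op with two non-leaf users
    // and a user with one non-leaf operand yields 10 / 2 / 1 == 5 below.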
    if (config < 0) {
      // Assume the reuse decreases with increasing user count.
      int count1 = CountNonLeafOps(op->users());
      int count2 = CountNonLeafOps(user->operands());
      return (-config) / count1 / count2;
    }
    return config;
  }
  void clear_recently_visited() {
    for (const auto& boundary : new_boundaries_) {
      visited_count_.erase(boundary.operands()[0]);
    }
  }
  // Returns true if `instruction` is worth hoisting.
  bool WorthHoisting(HloInstruction* instruction, bool is_inside_branch) {
    // This is needed for the "moving-in" transformation, to prevent the root
    // of the parent computation (which contains the conditional) from being
    // moved inside the conditional.
    HloOpcode opcode = instruction->opcode();
    if (opcode == HloOpcode::kTuple &&
        instruction == conditional_parent_->root_instruction()) {
      return false;
    }
    // It is not safe to move collective ops from outside to inside
    // conditional branches, as it may cause synchronization problems when
    // different layouts are assigned to different branches.
    if (opcode == HloOpcode::kAllReduce && !is_inside_branch) {
      return false;
    }

    // It is not legal to move the parameter instructions.
    if (opcode == HloOpcode::kParameter) {
      return false;
    }

    // Use the configuration given from outside (e.g., by the autotuner).
    std::vector<int64>& curconfig = move_config_[static_cast<uint32>(opcode)];
    auto col = (curconfig.size() == 1) ? 0
               : (instruction->operand_count() > 0)
                   ? static_cast<uint32>(instruction->operand(0)->opcode())
                   : 0;
    VLOG(2) << "column = " << col << "\n";
    VLOG(2) << "config size = " << curconfig.size() << "\n";
    VLOG(2) << "search_config = " << search_config_ << "\n";
    CHECK(col < curconfig.size());
    uint32 config = (search_config_ > 0)
                        ? FlipMutation(&curconfig[col], 1,
                                       "Move-" + HloOpcodeString(opcode))
                        : curconfig[col];
    VLOG(2) << "Checking instruction is worth moving: " << config << "\n";
    return (config != 0);
  }

  int64 ReusesBeforeBoundary(HloInstruction* user) {
    int64 reuses = 0;
    for (auto op : user->operands()) {
      // The operand must be an instruction that is not going to be moved (if
      // user is inside the conditional); otherwise it must be the conditional
      // itself and its user must be outside of the conditional.
      if (!ContainsKey(visited_count_, op) && op != conditional_) {
        continue;
      }
      if (auto tuple_gte = DynCast<HloGetTupleElementInstruction>(user)) {
        if (op->opcode() == HloOpcode::kConditional) {
          auto tuple = op->branch_computation(0)->root_instruction();
          if (tuple->opcode() == HloOpcode::kTuple) {
            auto index = tuple_gte->tuple_index();
            CHECK(index < tuple->operand_count());
            op = tuple->mutable_operand(index);
          }
        }
        reuses += ReusesCarriedBy(op, user->users()[0]);
      } else {
        reuses += ReusesCarriedBy(op, user);
      }
    }
    VLOG(2) << "Reuses before instruction " << user->ToString() << ":" << reuses
            << "\n";
    return reuses;
  }

  int64 ReusesAfterBoundary(HloInstruction* user) {
    CHECK(user != nullptr);
    auto all_users = user->users();
    // For now, assume that if an instruction has multiple consumers, it
    // will not be reused, as the reuse may require duplication in
    // fusion and so is expensive. If the situation changes in the future,
    // some aspects of the overall algorithm need to be redesigned to
    // accommodate the change.
    if (all_users.size() > 1) {
      VLOG(2) << "Having multiple users from: " << user->ToString() << "\n";
      return 0;
    }
    if (!all_users.empty()) {
      auto op = all_users[0];
      int64 reuses = 0;
      // Only count reuses that run through the conditional root.
      if (op == conditional_->branch_computation(0)->root_instruction()) {
        int64 index = op->operand_index(user);
        for (auto op2 : conditional_->users()) {
          // If the use is not a get-tuple-element, do not consider it for now.
          if (op2->opcode() == HloOpcode::kGetTupleElement) {
            auto tuple_opd = static_cast<HloGetTupleElementInstruction*>(op2);
            if (index == tuple_opd->tuple_index()) {
              all_users = op2->users();
              if (!all_users.empty()) {
                reuses += ReusesCarriedBy(user, all_users[0]);
                break;
              }
            }
          }
        }
      } else if (ContainsKey(visited_count_, op)) {
        reuses += ReusesCarriedBy(user, op);
      }
      VLOG(2) << "reuses after instruction " << user->ToString() << ":"
              << reuses << "\n";
      return reuses;
    }
    return 0;
  }

  int64 BenefitForMovingBoundaries(const std::vector<Boundary>& boundaries) {
    int64 reuses_before = 0, reuses_after = 0;
    if (boundaries.size() == 1) {
      if (boundaries[0].IsOutsideBranch() &&
          boundaries[0].operands()[0]->opcode() ==
              HloOpcode::kGetTupleElement) {
        // The only boundary of moving-in is the get_tuple_element op.
        return -1;
      }
      if (boundaries[0].IsInsideBranch() &&
          boundaries[0].operands()[0]->opcode() == HloOpcode::kTuple) {
        // The only boundary of moving-out is the tuple op inside branches.
        return -1;
      }
    }
    // If trying alternative moving configurations, turn off reuse analysis.
    if (search_config_ > 0) {
      return 1;
    }
    // For cases like:
    // branch0 {
    //   ROOT copy
    // }
    // branch1 {
    //   ...
    // }
    // cond = conditional(branch0, branch1)
    // copy = copy(cond)
    //
    // We can fold the two copies, thus reducing computation.
    auto get_copy_folding_benefit = [&](HloInstruction* hlo) -> int64 {
      if (hlo->opcode() != HloOpcode::kCopy) {
        return 0;
      }
      const HloGetTupleElementInstruction* gte =
          DynCast<HloGetTupleElementInstruction>(hlo->operand(0));
      if (gte == nullptr) {
        return 0;
      }
      const HloInstruction* conditional = gte->operand(0);
      if (conditional != conditional_) {
        return 0;
      }
      int64 benefit = 0;
      for (auto* branch : conditional->called_computations()) {
        HloInstruction* root = branch->root_instruction();
        if (root->opcode() == HloOpcode::kTuple) {
          const auto* tuple_operand = root->operand(gte->tuple_index());
          if (tuple_operand->opcode() == HloOpcode::kCopy) {
            if (Shape::Equal()(tuple_operand->operand(0)->shape(),
                               hlo->shape())) {
              benefit += 10;
            }
          }
        }
      }
      return benefit;
    };
    for (const Boundary& b : boundaries) {
      auto op = b.operands()[0];
      if (op == conditional_->branch_computation(0)->root_instruction()) {
        continue;
      }
      VLOG(2) << "Benefit for " << op->ToString();
      reuses_before += ReusesBeforeBoundary(op);
      VLOG(2) << "Reuses before boundary so far: " << reuses_before << "\n";
      reuses_after += ReusesAfterBoundary(op);
      VLOG(2) << "Reuses after boundary so far : " << reuses_after << "\n";
    }

    int64 copy_folding_benefit = 0;
    if (boundaries[0].IsOutsideBranch()) {
      for (const Boundary& b : boundaries) {
        auto op = b.operands()[0];
        copy_folding_benefit += get_copy_folding_benefit(op);
      }
    }
    VLOG(2) << "Copy folding benefit: " << copy_folding_benefit;

    if (reuses_after == 0 && reuses_before == 0 && copy_folding_benefit == 0) {
      return -1;
    } else if (boundaries[0].IsInsideBranch()) {
      return reuses_after - reuses_before;
    } else {
      return reuses_before - reuses_after - 1 + copy_folding_benefit;
    }
  }

  Boundary GetNextBoundary(const Boundary& b, int64 op_index) {
    Boundary b2(b.GetPosition());
    for (int j = 0; j < b.operands().size(); ++j) {
      HloInstruction* inst = b.operands()[j];
      CHECK(inst != nullptr);
      HloInstruction* op = (b.IsInsideBranch()) ? inst->operands()[op_index]
                                                : inst->users()[op_index];
      CHECK(op != nullptr);
      b2.mutable_operands().push_back(op);
    }
    return b2;
  }

  // Checks whether it is safe to move a boundary when visited through a
  // dependent already considered for moving.
  bool IsSafeToMoveBoundary(const Boundary& next_boundary) {
    int64 next_boundary_count =
        (next_boundary.IsInsideBranch())
            ? next_boundary.operands()[0]->user_count()
            : CountNonLeafOps(next_boundary.operands()[0]->operands());
    if (next_boundary_count <= 1) {
      // If the boundary has at most one dependent, it is safe to move.
      return true;
    } else {
      if (!ContainsKey(visited_count_, next_boundary.operands()[0])) {
        VLOG(2) << "Skip next boundary " << next_boundary.ToString() << "\n"
                << " because it has multiple dependents: "
                << next_boundary_count << "\n";
        visited_count_[next_boundary.operands()[0]] = 1;
        new_boundaries_.push_back(next_boundary);
      } else {
        auto pos = std::find(new_boundaries_.begin(), new_boundaries_.end(),
                             next_boundary);
        if (pos != new_boundaries_.end() ||
            next_boundary.operands().size() == 1) {
          int count = ++visited_count_[next_boundary.operands()[0]];
          if (count == next_boundary_count) {
            VLOG(2) << "Recovering next boundary " << next_boundary.ToString()
                    << "\n"
                    << " because all of its dependents have been visited: "
                    << next_boundary_count << "\n";
            visited_count_.erase(next_boundary.operands()[0]);
            if (pos != new_boundaries_.end()) {
              new_boundaries_.erase(pos);
            }
            return true;
          }
        } else {
          VLOG(2) << "Skip incompatible multi-dependent boundary: "
                  << next_boundary.ToString() << ":" << next_boundary_count
                  << "\n";
        }
      }
    }
    return false;
  }
  // This function is reused for moving boundaries both out of and into a
  // conditional. As a result, readability is somewhat compromised.
  // It might be nice to refactor it so that the outside-vs-inside
  // considerations are factored into separate function pointer parameters
  // to improve readability.
  void AddBoundaries(const Boundary& boundary) {
    BoundaryVisitor visitor;
    visitor.AddToWorkList(boundary);
    while (visitor.HasNextBoundary()) {
      Boundary b = visitor.PopNextBoundary();
      VLOG(2) << "visiting boundary " << b.ToString() << "\n";
      if ((b.IsOutsideBranch() || InstructionWithinBranchIdentical(
                                      b.operands(), is_layout_sensitive_)) &&
          IsSafeToMoveBoundary(b) &&
          WorthHoisting(b.operands()[0], b.IsInsideBranch())) {
        connected_boundaries_.push_back(b);
        VLOG(2) << "boundary can be moved\n";
        int64 operand_count = (b.IsInsideBranch())
                                  ? b.operands()[0]->operand_count()
                                  : b.operands()[0]->users().size();
        for (int i = 0; i < operand_count; i++) {
          Boundary next_boundary = GetNextBoundary(b, i);
          VLOG(2) << "Add operand/user " << i << " to visit later\n";
          visitor.AddToWorkList(next_boundary);
        }
      } else {
        VLOG(2) << "boundary cannot be moved\n";
        visited_count_[b.operands()[0]] = 1;
        new_boundaries_.push_back(b);
      }
    }
  }
  std::vector<Boundary> BoundariesToMoveInOrOut(HloInstruction* conditional,
                                                const Boundary& b) {
    // At the beginning of optimization, a conditional itself is added to a
    // worklist. Here the conditional is expanded into two sets of boundaries:
    // the first set contains the boundary that is inside branches and
    // contains the root of all branches; the second set of boundaries
    // contains all the users of the conditional.
    HloInstruction* inst = b.operands()[0];
    if (inst == conditional) {
      int branch_count = inst->branch_count();
      // Add conditional roots as a new boundary to visit.
      Boundary boundary_in(Boundary::Position::kInsideBranch);
      for (int i = 0; i < branch_count; i++) {
        HloComputation* branch_computation = inst->branch_computation(i);
        HloInstruction* root_inst = branch_computation->root_instruction();
        CHECK(root_inst != nullptr);
        boundary_in.mutable_operands().push_back(root_inst);
      }
      new_boundaries_.push_back(boundary_in);
      // Add conditional users as new boundaries to visit.
      for (auto u : inst->users()) {
        Boundary boundary_in(Boundary::Position::kOutsideBranch);
        boundary_in.mutable_operands().push_back(u);
        new_boundaries_.push_back(boundary_in);
      }
    } else {
      AddBoundaries(b);
    }
    return connected_boundaries_;
  }
  void AddNewBoundaries(std::vector<Boundary>& b) {
    b.insert(b.end(), new_boundaries_.begin(), new_boundaries_.end());
  }
};

ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion(
    HloInstruction* conditional, const Boundary& cur_boundary,
    std::vector<Boundary>& to_move, std::vector<Boundary>& new_boundaries,
    absl::flat_hash_map<HloInstruction*, int>& visited_count) {
  GroupConnectedBoundaries connect(conditional, is_layout_sensitive_,
                                   visited_count, &move_config_,
                                   &reuse_config_, &search_config_);
  auto move_in_or_out =
      connect.BoundariesToMoveInOrOut(conditional, cur_boundary);
  if (!move_in_or_out.empty()) {
    auto benefit = connect.BenefitForMovingBoundaries(move_in_or_out);
    VLOG(2) << "benefit of moving in or out "
            << cur_boundary.operands()[0]->ToString() << ":" << benefit
            << "\n";
    if (benefit >= 0) {
      new_boundaries.clear();
      connect.AddNewBoundaries(new_boundaries);
      // The whole sequence in move_in_or_out is either all moving into a
      // conditional, or all moving out of a conditional. So looking only
      // at the first entry of the sequence is sufficient to know in which
      // direction the move is intended.
      to_move = move_in_or_out;
      return Decision(to_move[0].IsInsideBranch()
                          ? Decision::Direction::kMoveOutOfBranch
                          : Decision::Direction::kMoveIntoBranch,
                      benefit);
    } else {
      connect.clear_recently_visited();
    }
  } else {
    connect.AddNewBoundaries(new_boundaries);
  }
  return Decision(Decision::Direction::kNoChange, 0);
}

StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
  VLOG(2) << "Begin a new pass of conditional code motion optimization.\n";
  // Used to support debugging of the optimization: the opt is disabled after
  // it has been applied a pre-determined number of times, to isolate the
  // impact of the transformations.
  if (!ConsumeFuel("conditional_code_motion", [&] {
        return "Skipping conditional opt after allowed limit reaching 0.\n";
      })) {
    return false;
  }
  bool changed = false;
  bool cleanup_changed = false;
  {
    HloPassPipeline subpipeline("before_conditional_code_motion");
    subpipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/is_layout_sensitive_);
    subpipeline.AddPass<HloDCE>();
    TF_ASSIGN_OR_RETURN(auto cleanup_changed_now, subpipeline.Run(module));
    cleanup_changed |= cleanup_changed_now;
  }
  // Set the default configuration.
  VLOG(2) << "Obtaining default configuration\n";
  SetDefaultMoveConfig();
  VLOG(2) << "Done obtaining default configuration\n";
  // Gather all the conditional ops in the module ahead of time, to avoid
  // potential complications of modifying the code that would affect
  // traversal.
  std::vector<HloInstruction*> conditional_ops;
  // Track how many times each branch computation is shared. The count starts
  // at 0 for the first reference, so a value > 0 means the computation is
  // shared with another conditional.
  absl::flat_hash_map<HloComputation*, int> conditional_computations;
  for (auto* comp : module->MakeComputationPostOrder()) {
    for (auto* instr : comp->MakeInstructionPostOrder()) {
      if (instr->opcode() == HloOpcode::kConditional) {
        int branch_count = instr->branch_count();
        for (int i = 0; i < branch_count; ++i) {
          HloComputation* branch_i = instr->branch_computation(i);
          if (ContainsKey(conditional_computations, branch_i)) {
            conditional_computations[branch_i]++;
          } else {
            conditional_computations[branch_i] = 0;
          }
        }
        if (instr->shape().IsTuple()) {
          bool can_change_tuple_shape = true;
          for (auto user : instr->users()) {
            VLOG(2) << "user is : " << user->ToString() << "\n";
            if (user->opcode() != HloOpcode::kGetTupleElement) {
              can_change_tuple_shape = false;
            }
          }
          if (can_change_tuple_shape) {
            conditional_ops.push_back(instr);
          }
        } else {
          conditional_ops.push_back(instr);
        }
      }
    }
  }
1330
1331 // Use to collect mappings between cloned instructions.
1332 HloCloneContext clone_context(module);
1333 for (HloInstruction* conditional : conditional_ops) {
1334 int branch_count = conditional->branch_count();
1335 // check for shared conditional computations
1336 bool conditional_is_shared = false;
1337 for (int i = 0; i < branch_count; ++i) {
1338 HloComputation* branch_i = conditional->branch_computation(i);
1339 if (conditional_computations[branch_i] > 0) {
1340 conditional_is_shared = true;
1341 break;
1342 }
1343 }
1344
1345 // Boundaries to move out or to move into the branches.
1346 std::vector<std::vector<Boundary> > to_move_out, to_move_in;
1347 std::vector<std::vector<Boundary> > new_boundaries_for_moveout;
1348 std::vector<std::vector<Boundary> > new_boundaries_for_movein;
1349 // Number of times each instruction has been visited for moving.
1350 absl::flat_hash_map<HloInstruction*, int> visited_count;
1351 int benefit_move_out = 0, benefit_move_in = 0;
1352 Decision::Direction final_d = Decision::Direction::kNoChange;
    // The conditional is moved into a worklist as the seed (starting point).
    // The conditional will be expanded into multiple seeds (starting points),
    // its roots and its users, when it is visited by GroupConnectedBoundaries.
    // A kNoChange decision is always returned for the conditional itself, so
    // that the other seed boundaries can be visited in turn.
    BoundaryVisitor visitor(conditional);
    VLOG(2) << "Analyzing conditional:" << conditional->ToString() << "\n";
    // Try to visit all the boundaries, collect the analysis results, and save
    // all the beneficial non-conflicting decisions. If two decisions conflict
    // with each other, save the more beneficial one.
    while (visitor.HasNextBoundary()) {
      std::vector<Boundary> to_move, next_boundary;
      Boundary boundary = visitor.PopNextBoundary();
      VLOG(2) << "Analyzing boundary:" << boundary.ToString() << "\n";
      auto d = ConsiderCodeMotion(conditional, boundary, to_move, next_boundary,
                                  visited_count);
      switch (d.GetDirection()) {
        case Decision::Direction::kMoveOutOfBranch:
          VLOG(2) << "Local Decision is move out of branch\n";
          to_move_out.push_back(to_move);
          new_boundaries_for_moveout.push_back(next_boundary);
          benefit_move_out += d.GetBenefit();
          if (benefit_move_out >= benefit_move_in) {
            final_d = Decision::Direction::kMoveOutOfBranch;
            VLOG(2) << "Current Decision is move out of branch ("
                    << to_move_out.size() << ")\n";
          } else {
            VLOG(2) << "Current Decision remains move into branch\n";
          }
          break;
        case Decision::Direction::kMoveIntoBranch:
          VLOG(2) << "Decision is move into branch\n";
          to_move_in.push_back(to_move);
          new_boundaries_for_movein.push_back(next_boundary);
          benefit_move_in += d.GetBenefit();
          if (benefit_move_out >= benefit_move_in) {
            VLOG(2) << "Current Decision remains move out of branch\n";
          } else {
            final_d = Decision::Direction::kMoveIntoBranch;
            VLOG(2) << "Current Decision is move into branch ("
                    << to_move_in.size() << ")\n";
          }
          break;
        case Decision::Direction::kNoChange:
          VLOG(2) << "Decision is no change\n";
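          // No beneficial motion was found at this boundary; keep exploring
          // its successor boundaries instead.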
          for (const Boundary& b : next_boundary) {
            visitor.AddToWorkList(b);
            VLOG(2) << "Adding new boundary to worklist:" << b.ToString()
                    << "\n";
          }
          break;
      }
    }
    // If a modification is to be made, the shared branches need to be cloned
    // first.
    if (final_d != Decision::Direction::kNoChange && conditional_is_shared) {
      for (int i = 0; i < branch_count; ++i) {
        HloComputation* branch_i = conditional->branch_computation(i);
        if (conditional_computations[branch_i] > 0) {
          // Cloning is absolutely needed if the computation is shared by
          // different conditionals, but it can potentially be avoided if the
          // sharing is only among branches of the same conditional. If
          // cloning these branches causes a problem due to space issues, a
          // fix can pass a vector of unique branches to the actual
          // transformations, as an alternative representation of the
          // conditional branches to be modified. Right now we assume the
          // overhead of cloning is minimal since later stages of the compiler
          // inline all the computations anyway.
          HloComputation* clone_i =
              conditional->parent()->parent()->AddEmbeddedComputation(
                  branch_i->Clone("clone", &clone_context));
          conditional->set_branch_computation(i, clone_i);
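          // This conditional now refers to the clone, so one fewer reference
          // to the original computation remains.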
          conditional_computations[branch_i]--;
          // Translate the analysis results so that they refer to the cloned
          // instructions.
          auto update_boundary = [&](Boundary& boundary) {
            auto cloned_instr =
                clone_context.FindInstruction(boundary.operands()[i]);
            CHECK(cloned_instr != nullptr);
            VLOG(2) << "boundary before cloning:" << boundary.operands()[i]
                    << "\n";
            boundary.mutable_operands()[i] = cloned_instr;
            VLOG(2) << "boundary after cloning:" << boundary.operands()[i]
                    << "\n";
          };
          // Only boundaries to be moved out need updating: they refer to
          // instructions inside the (now cloned) branches, whereas boundaries
          // to be moved in refer to instructions outside the conditional.
          // Note: update_boundary captures the branch index i, so the loops
          // below use j.
          if (final_d == Decision::Direction::kMoveOutOfBranch) {
            for (int j = 0; j < to_move_out.size(); ++j) {
              std::vector<Boundary>& m = to_move_out[j];
              std::for_each(m.begin(), m.end(), update_boundary);
            }
            for (int j = 0; j < new_boundaries_for_moveout.size(); ++j) {
              std::vector<Boundary>& m = new_boundaries_for_moveout[j];
              std::for_each(m.begin(), m.end(), update_boundary);
            }
          }
        }
      }
      VLOG(2) << "Cloned branches as needed: " << conditional->ToString()
              << "\n";
    }
    // Decisions may have been collected for both directions, but only the
    // direction chosen by final_d is applied.
    if (final_d == Decision::Direction::kMoveOutOfBranch) {
      CHECK(to_move_out.size() == new_boundaries_for_moveout.size());
      for (int i = 0; i < to_move_out.size(); ++i) {
        TF_ASSIGN_OR_RETURN(bool result,
                            MoveInstructionOut(conditional, to_move_out[i],
                                               new_boundaries_for_moveout[i]));
        changed |= result;
      }
      VLOG(2) << "Done moving out of branches " << to_move_out.size()
              << " times. \n";
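      // Fuel is consumed each time a transformation is applied, so that the
      // pass stops after a pre-determined number of applications when
      // debugging (see the fuel check at the top of Run).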
      if (!ConsumeFuel("conditional_code_motion", [&] {
            return "Skipping conditional opt after the allowed limit reached "
                   "0.\n";
          })) {
        break;
      }
    } else if (final_d == Decision::Direction::kMoveIntoBranch) {
      CHECK(to_move_in.size() == new_boundaries_for_movein.size());
      for (int i = 0; i < to_move_in.size(); ++i) {
        TF_ASSIGN_OR_RETURN(bool result,
                            MoveInstructionIn(conditional, to_move_in[i],
                                              new_boundaries_for_movein[i]));
        changed |= result;
      }
      VLOG(2) << "Done moving into branches " << to_move_in.size()
              << " times. \n";
      if (!ConsumeFuel("conditional_code_motion", [&] {
            return "Skipping conditional opt after the allowed limit reached "
                   "0.\n";
          })) {
        break;
      }
    } else if (pursue_full_conditional_code_motion_ && !conditional_is_shared) {
      // Invoke special handling for convert rematerialization/hoisting.
      // The branches must not be shared here, because no cloning has been
      // done by the earlier analysis.
      // TODO(b/165848866): extend the solution to handle cloning for the
      // special move.
      TF_ASSIGN_OR_RETURN(
          bool convert_result,
          ConvertSpecialMove(conditional, is_layout_sensitive_));
      if (convert_result) {
        VLOG(2) << "Done special moving of convert\n";
        if (!ConsumeFuel("conditional_code_motion", [&] {
              return "Skipping conditional opt after the allowed limit "
                     "reached 0.\n";
            })) {
          break;
        }
      }
      changed |= convert_result;
    }
  }
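  // Code motion typically leaves behind dead get-tuple-elements and
  // pass-through tuples, so run tuple simplification and DCE before
  // returning.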
  if (changed) {
    HloPassPipeline subpipeline(
        "after_conditional_code_motion_after_convert_hoisting");
    VLOG(2) << "starting after motion passes: DCE\n";
    subpipeline.AddPass<HloDCE>();
    subpipeline.AddPass<TupleSimplifier>();
    subpipeline.AddPass<HloDCE>();
    TF_ASSIGN_OR_RETURN(auto cleanup_changed_now, subpipeline.Run(module));
    cleanup_changed |= cleanup_changed_now;
  }
  if (cleanup_changed) {
    VLOG(2) << "subpipeline cleanup has modified the code\n";
  }
  return changed;
}
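
// Example usage (a sketch, not part of the original file): the pass is
// normally registered in an HloPassPipeline. The constructor arguments shown
// here are inferred from the member variables used above; consult
// conditional_code_motion.h for the authoritative signature.
//
//   HloPassPipeline pipeline("conditional-code-motion");
//   pipeline.AddPass<ConditionalCodeMotion>(
//       /*is_layout_sensitive=*/true,
//       /*pursue_full_conditional_code_motion=*/true);
//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));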

void ConditionalCodeMotion::SetDefaultMoveConfig() {
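  // Interpretation of search_config_: 0 selects the default configuration
  // (no tuning); a positive value tunes the transformation decisions
  // themselves (case 1 below); a negative value tunes only the
  // ReusesCarriedBy estimates (case 2 below).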
  int tuning_option = (search_config_ == 0) ? 0 : (search_config_ > 0) ? 1 : 2;

  auto row = HloOpcodeCount();
  auto col = row;
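  // Both configuration tables are square: rows are indexed by the opcode of
  // an instruction and columns by the opcode of its first operand.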
  VLOG(2) << "Start setting default configuration\n";
  reuse_config_.reserve(row);
  move_config_.reserve(row);
  for (int64 opcode = 0; opcode < row; ++opcode) {
    // Record the reuse benefit carried between each pair of opcodes.
    std::vector<int64> reuse_vec(col, 0);
    for (uint32 j = 0; j < col; ++j) {
      reuse_vec[j] = ReusesCarriedBy(static_cast<HloOpcode>(opcode),
                                     static_cast<HloOpcode>(j));
    }
    reuse_config_.push_back(reuse_vec);
    std::vector<int64> move_vec;
    switch (tuning_option) {
      case 1:
        // Tuning the transformation decisions --- start with all yes.
        // Only a single entry is needed if we don't consider operands of an
        // op when searching/tuning transformation decisions.
        move_vec.push_back(1);
        break;
      case 2:  // Tune the ReusesCarriedBy results only.
      case 0:
        // No tuning --- use the default configuration.
        // Use the opcode of the first operand to configure the default.
        move_vec.reserve(col);
        for (uint32 j = 0; j < col; ++j) {
          move_vec.push_back(WorthHoisting(static_cast<HloOpcode>(opcode),
                                           static_cast<HloOpcode>(j)));
        }
        break;
    }
    move_config_.push_back(move_vec);
  }
}

}  // namespace conditional_opt

}  // namespace xla