• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/conditional_code_motion.h"
17 
18 #include <iterator>
19 #include <stack>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/algorithm/container.h"
25 #include "absl/strings/str_cat.h"
26 #include "tensorflow/compiler/xla/debug_options_flags.h"
27 #include "tensorflow/compiler/xla/literal.h"
28 #include "tensorflow/compiler/xla/map_util.h"
29 #include "tensorflow/compiler/xla/service/call_graph.h"
30 #include "tensorflow/compiler/xla/service/call_inliner.h"
31 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
32 #include "tensorflow/compiler/xla/service/hlo_computation.h"
33 #include "tensorflow/compiler/xla/service/hlo_cse.h"
34 #include "tensorflow/compiler/xla/service/hlo_dce.h"
35 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
36 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
37 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
38 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
39 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
40 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
41 #include "tensorflow/compiler/xla/shape_util.h"
42 #include "tensorflow/compiler/xla/status_macros.h"
43 #include "tensorflow/compiler/xla/statusor.h"
44 #include "tensorflow/compiler/xla/types.h"
45 #include "tensorflow/compiler/xla/util.h"
46 #include "tensorflow/core/lib/core/errors.h"
47 #include "tensorflow/core/platform/errors.h"
48 
49 namespace xla {
50 
51 namespace conditional_opt {
52 
53 class BoundaryVisitor {
54  public:
55   // start with an existing conditional computation.
BoundaryVisitor(HloInstruction * conditional)56   explicit BoundaryVisitor(HloInstruction* conditional) {
57     Boundary b(Boundary::Position::kInsideBranch);
58     b.mutable_operands().push_back(conditional);
59     worklist_.push_back(b);
60   }
61   // Start with an empty work list.
BoundaryVisitor()62   BoundaryVisitor() {}
63   // Get next boundary to visit.
PopNextBoundary()64   Boundary PopNextBoundary() {
65     CHECK(!worklist_.empty());
66     Boundary b = worklist_.front();
67     worklist_.pop_front();
68     // if b is already visited, it must have multiple users and is already in
69     // new boundaries. Skip it. Only checking the first operand of b because b
70     // is expected to have at least one operand, and all the operands in b
71     // must be identical instructions from different branches for b to be moved.
72     while (!worklist_.empty() && ContainsKey(visited_, b.operands()[0])) {
73       b = worklist_.front();
74       worklist_.pop_front();
75     }
76     visited_.insert(b.operands()[0]);
77     return b;
78   }
AddToWorkList(const Boundary & b)79   void AddToWorkList(const Boundary& b) {
80     CHECK(!b.operands().empty());
81     worklist_.push_back(b);
82   }
83 
HasNextBoundary()84   bool HasNextBoundary() {
85     while (!worklist_.empty()) {
86       Boundary b = worklist_.front();
87       if (!ContainsKey(visited_, b.operands()[0])) {
88         break;
89       }
90       worklist_.pop_front();
91     }
92     return !worklist_.empty();
93   }
94 
95  private:
96   // worklist is the deque that contains instructions to be visited.
97   std::deque<Boundary> worklist_;
98   absl::flat_hash_set<HloInstruction*> visited_;
99 };
100 
101 template <class OpCollection>
CountNonLeafOps(const OpCollection & ops)102 int64 CountNonLeafOps(const OpCollection& ops) {
103   absl::flat_hash_set<HloInstruction*> op_set;
104   for (auto op : ops) {
105     if (!op_set.contains(op) && op->opcode() != HloOpcode::kConstant) {
106       op_set.insert(op);
107     }
108   }
109   return op_set.size();
110 }
111 
112 // Returns estimation of potential reuses carried by a given pair of
113 // instructions.  Use different integers to classify different levels
114 // of reuses This is used as a placeholder only, assuming all
115 // instructions can be fused to enable data reuses
ReusesCarriedBy(HloOpcode op,HloOpcode user)116 int64 ReusesCarriedBy(HloOpcode op, HloOpcode user) {
117   // Reuses in some way work like forces that pull instructions
118   // towards each other. We use a number 0-10 to classify how strong the force
119   // is between a pair of operations. Given a group of instructions that can be
120   // moved together, if the forces inside a conditional are stronger, the group
121   // will be moved incide or remain inside the conditional; otherwise, it will
122   // be moved outside to or remain outside of the conditional.
123   switch (user) {
124     case HloOpcode::kGetTupleElement:
125       return 0;
126     case HloOpcode::kConvert:
127       // Because convert is treated not moveable when following Dot or
128       // convolution, here if op is dot or convolution, they must be separated
129       // by a conditional boundary. Here we do not try to pull convert inside
130       // conditionals to be together with the dot or convolution.
131       switch (op) {
132         case HloOpcode::kConvolution:
133         case HloOpcode::kDot:
134           return 0;
135         default:
136           break;
137       }
138       break;
139     default:
140       break;
141   }
142   switch (op) {
143       // These instructions do not carry weight of reuse themselves.
144     case HloOpcode::kParameter:
145     case HloOpcode::kConstant:
146     case HloOpcode::kGetTupleElement:
147       return 0;
148     case HloOpcode::kConditional:
149       return 10;
150     default:
151       return -10;
152   }
153 }
154 
155 // Returns true if `op` is worth hoisting.
WorthHoisting(HloOpcode op,HloOpcode child_op)156 bool WorthHoisting(HloOpcode op, HloOpcode child_op) {
157   // TOOD[b/169182921] The following cost model is rather incomplete. Will
158   // need to extend to cover most of element-wise ops.
159   switch (op) {
160     case HloOpcode::kConvert:
161       // If Convert is after AllReduce, it is worth moving out AllReduce
162       // out of conditional for AR/CRS combine. If Convert is after other
163       // ops such as Dot or Convolutional, it is better to keep convert
164       // within conditional so that convert can be fused with Dot or
165       // Convolutional.
166       switch (child_op) {
167         case HloOpcode::kAllReduce:
168         case HloOpcode::kReshape:
169         case HloOpcode::kGetTupleElement:
170           return true;
171         default:
172           return false;
173       }
174     case HloOpcode::kGetTupleElement:
175       switch (child_op) {
176         // do not move GTE if its operand is a parameter
177         case HloOpcode::kParameter:
178           return false;
179         default:
180           return true;
181       }
182     case HloOpcode::kAllReduce:
183     case HloOpcode::kAbs:
184     case HloOpcode::kReduce:
185     case HloOpcode::kAdd:
186     case HloOpcode::kPower:
187     case HloOpcode::kCopy:
188     case HloOpcode::kConstant:
189     case HloOpcode::kSubtract:
190     case HloOpcode::kMultiply:
191     case HloOpcode::kDivide:
192     case HloOpcode::kTuple:
193     case HloOpcode::kSqrt:
194     case HloOpcode::kRsqrt:
195     case HloOpcode::kReshape:
196     case HloOpcode::kMinimum:
197     case HloOpcode::kMaximum:
198       return true;
199     default:
200       return false;
201   }
202 }
203 
204 // Compare if the instructions to be visited at each branches are identical.
InstructionWithinBranchIdentical(const std::vector<HloInstruction * > & instructions,bool is_layout_sensitive)205 bool InstructionWithinBranchIdentical(
206     const std::vector<HloInstruction*>& instructions,
207     bool is_layout_sensitive) {
208   // Identical includes the shape of each operands are equal.
209   auto eq_operand = [&](const HloInstruction* a, const HloInstruction* b) {
210     bool eq_operands = is_layout_sensitive
211                            ? ShapeUtil::Equal(a->shape(), b->shape())
212                            : ShapeUtil::Compatible(a->shape(), b->shape());
213     return eq_operands;
214   };
215 
216   auto eq_computations = [](const HloComputation* a, const HloComputation* b) {
217     return *a == *b;
218   };
219 
220   if (instructions.empty()) {
221     return false;
222   }
223 
224   if (instructions[0]->IsCrossModuleAllReduce()) {
225     return std::all_of(
226         instructions.begin(), instructions.end(),
227         [&](HloInstruction* instruction) {
228           if (!instruction->IsCrossModuleAllReduce()) {
229             return false;
230           }
231           auto old_channel_id = instruction->channel_id();
232           instruction->set_channel_id(instructions[0]->channel_id());
233           bool eq_instructions = instructions[0]->Identical(
234               *instruction, eq_operand, eq_computations, is_layout_sensitive);
235           instruction->set_channel_id(old_channel_id);
236           return eq_instructions;
237         });
238   }
239 
240   return std::all_of(instructions.begin(), instructions.end(),
241                      [&](HloInstruction* instruction) {
242                        return instructions[0]->Identical(
243                            *instruction, eq_operand, eq_computations,
244                            is_layout_sensitive);
245                      });
246 }
247 
248 // Copy the ith instruction in boundary to outside of conditional, or do the
249 // opposite (for moving in).
CopyInOrOutOfConditional(Boundary & boundary,int64 dest_index,HloComputation * parent,absl::flat_hash_map<HloInstruction *,Boundary> & hoisted_instructions)250 Status CopyInOrOutOfConditional(
251     Boundary& boundary, int64 dest_index, HloComputation* parent,
252     absl::flat_hash_map<HloInstruction*, Boundary>& hoisted_instructions) {
253   CHECK(dest_index == 0 || boundary.IsOutsideBranch());
254   HloInstruction* op = boundary.operands()[0];
255   absl::InlinedVector<HloInstruction*, 4> new_operands;
256   for (int i = 0; i < op->operands().size(); ++i) {
257     auto op_i = op->operands()[i];
258     VLOG(2) << "Looking for " << op_i->ToString() << "\n";
259     if (ContainsKey(hoisted_instructions, op_i)) {
260       auto new_op_i =
261           FindOrDie(hoisted_instructions, op_i).operands()[dest_index];
262       VLOG(2) << "new instruction:" << new_op_i->ToString() << "\n";
263       new_operands.push_back(new_op_i);
264     } else {
265       switch (op_i->opcode()) {
266         case HloOpcode::kConstant: {
267           auto new_op_i = parent->AddInstruction(op_i->Clone());
268           VLOG(2) << "new instruction:" << new_op_i->ToString() << "\n";
269           new_operands.push_back(new_op_i);
270           break;
271         }
272         case HloOpcode::kGetTupleElement: {
273           auto gte = Cast<HloGetTupleElementInstruction>(op_i);
274           int64 index = gte->tuple_index();
275           HloInstruction* root = parent->root_instruction();
276           CHECK(root->opcode() == HloOpcode::kTuple &&
277                 index < root->operand_count());
278           auto new_op_i = root->mutable_operand(index);
279           VLOG(2) << "new instruction:" << new_op_i->ToString() << "\n";
280           new_operands.push_back(new_op_i);
281           break;
282         }
283         default:
284           LOG(FATAL) << "Unexpected out-of-boundary instruction:"
285                      << op_i->ToString() << "\n";
286       }
287     }
288   }
289   HloInstruction* new_instruction = parent->AddInstruction(
290       op->CloneWithNewOperands(op->shape(), new_operands));
291   VLOG(2) << "new instruction:" << new_instruction->ToString() << "\n";
292   // Maps the instruction outside of conditional to the instruction
293   // inside of the conditional.
294   for (HloInstruction* op : boundary.operands()) {
295     Boundary b2 = ContainsKey(hoisted_instructions, op)
296                       ? hoisted_instructions[op]
297                       : Boundary(boundary.IsOutsideBranch()
298                                      ? Boundary::Position::kInsideBranch
299                                      : Boundary::Position::kOutsideBranch);
300     b2.mutable_operands().push_back(new_instruction);
301     hoisted_instructions[op] = b2;
302   }
303   return Status::OK();
304 }
305 
306 // Identify converts to be hoisted/rematerialized out of the branch
307 // computations.
FindSpecialConverts(HloInstruction * old_root,int branch_count,HloInstruction * conditional,bool is_layout_sensitive)308 absl::flat_hash_set<int64> FindSpecialConverts(HloInstruction* old_root,
309                                                int branch_count,
310                                                HloInstruction* conditional,
311                                                bool is_layout_sensitive) {
312   absl::flat_hash_set<int64> kspecial_convert;
313   for (int64 operand_num = 0; operand_num < old_root->operand_count();
314        ++operand_num) {
315     if (old_root->operand(operand_num)->opcode() != HloOpcode::kConvert) {
316       continue;
317     }
318     bool replica = true;
319     HloInstruction* kspecial_convert_candidate =
320         old_root->mutable_operand(operand_num);
321     // Check whether an identical candidate appears in other branches
322     for (int others = 1; others < branch_count; ++others) {
323       HloInstruction* others_root =
324           conditional->branch_computation(others)->root_instruction();
325       bool eq_shape =
326           is_layout_sensitive
327               ? ShapeUtil::Equal(others_root->operand(operand_num)->shape(),
328                                  kspecial_convert_candidate->shape())
329               : ShapeUtil::Compatible(
330                     others_root->operand(operand_num)->shape(),
331                     kspecial_convert_candidate->shape());
332       if ((others_root->operand(operand_num)->opcode() ==
333            HloOpcode::kConvert) &&
334           eq_shape) {
335         // Nothing to be done.
336       } else {
337         replica = false;
338         break;
339       }
340     }
341     if (replica) {
342       kspecial_convert.insert(operand_num);
343     }
344   }
345   return kspecial_convert;
346 }
347 
348 // Restructuring the conditional instruction as follows:
349 // i.e., %result = conditional() becomes
350 // x = conditional()
351 // y.{0..n} = gte(x, {0..n})
352 // z = tuple(y.0, y.1, ...y.n)
353 // Doing so ensures that we can accommodate the possible shape-change of the
354 // conditional when the instructions are hoisted.
RestructureConditionalInstruction(HloComputation * computation,HloInstruction * conditional)355 Status RestructureConditionalInstruction(HloComputation* computation,
356                                          HloInstruction* conditional) {
357   HloInstruction* old_root = computation->root_instruction();
358   std::vector<HloInstruction*> new_operands;
359   int cur_index = 0;
360   for (; cur_index < ShapeUtil::TupleElementCount(conditional->shape());
361        ++cur_index) {
362     new_operands.push_back(
363         computation->AddInstruction(HloInstruction::CreateGetTupleElement(
364             ShapeUtil::GetTupleElementShape(conditional->shape(), cur_index),
365             conditional, cur_index)));
366   }
367   HloInstruction* new_tuple =
368       computation->AddInstruction(HloInstruction::CreateTuple(new_operands));
369   if (old_root == conditional) {
370     computation->set_root_instruction(new_tuple);
371   } else {
372     std::vector<HloInstruction*> new_tuple_users;
373     for (auto conditional_user : conditional->users()) {
374       auto is_new_gte = absl::c_find_if(
375           new_operands,
376           [&](HloInstruction* instr) { return instr == conditional_user; });
377       if (is_new_gte == new_operands.end()) {
378         new_tuple_users.push_back(conditional_user);
379       }
380     }
381     for (auto new_tuple_user : new_tuple_users) {
382       TF_RETURN_IF_ERROR(
383           conditional->ReplaceUseWith(new_tuple_user, new_tuple));
384     }
385   }
386   VLOG(2) << "computation after root restructure:\n" << computation->ToString();
387   return Status::OK();
388 }
389 
ConvertSpecialMove(HloInstruction * conditional,bool is_layout_sensitive)390 StatusOr<bool> ConvertSpecialMove(HloInstruction* conditional,
391                                   bool is_layout_sensitive) {
392   int branch_count = conditional->branch_count();
393   if (branch_count <= 0) {
394     return false;
395   }
396 
397   // Determining whether all branch roots are tuples
398   for (int branch_num = 0; branch_num < branch_count; ++branch_num) {
399     HloInstruction* branch_root =
400         conditional->branch_computation(branch_num)->root_instruction();
401     if (branch_root->opcode() != HloOpcode::kTuple) {
402       return false;
403     }
404   }
405 
406   HloInstruction* old_root =
407       conditional->branch_computation(0)->root_instruction();
408   VLOG(2) << "BEFORE :" << conditional->parent()->parent()->ToString();
409   // Identify the gte using `index'.
410   auto find_gte = [](const HloInstruction* conditional_result,
411                      int64 index) -> HloInstruction* {
412     for (HloInstruction* instr : conditional_result->users()) {
413       if (instr->opcode() != HloOpcode::kGetTupleElement) {
414         return nullptr;
415       }
416       if (instr->tuple_index() == index) {
417         return instr;
418       }
419     }
420     return nullptr;
421   };
422 
423   // Captures tuple indices refering to converts to be rematerialized/hoisted.
424   absl::flat_hash_set<int64> kspecial_convert = FindSpecialConverts(
425       old_root, branch_count, conditional, is_layout_sensitive);
426 
427   // Exit if we cannot find any converts to be hoisted.
428   if (kspecial_convert.empty()) {
429     return false;
430   }
431 
432   TF_RETURN_IF_ERROR(
433       RestructureConditionalInstruction(conditional->parent(), conditional));
434 
435   for (int branch = 0; branch < branch_count; branch++) {
436     old_root = conditional->branch_computation(branch)->root_instruction();
437     absl::flat_hash_map<HloInstruction*, int64> map_inst_to_tuple_index;
438     std::vector<HloInstruction*> new_operands(old_root->operand_count());
439     absl::flat_hash_set<HloInstruction*> to_hoist_set;
440 
441     for (int64 operand_num = 0; operand_num < old_root->operand_count();
442          ++operand_num) {
443       map_inst_to_tuple_index[old_root->mutable_operand(operand_num)] =
444           operand_num;
445     }
446     for (int64 operand_num = 0; operand_num < old_root->operand_count();
447          ++operand_num) {
448       HloInstruction* hoist = old_root->mutable_operand(operand_num);
449       if (!kspecial_convert.contains(operand_num)) {
450         new_operands[operand_num] = old_root->mutable_operand(operand_num);
451         continue;
452       }
453 
454       to_hoist_set.insert(hoist);
455       int64 new_tuple_count = old_root->operand_count();
456 
457       // Replace the hoisted instr in the tuple with the operand/operands.
458       // We will replace at least one of the operands of the hoist at the
459       // tuple place; the rest will be added at the end.
460       bool inplace = true;
461       CHECK(!hoist->operands().empty());
462       for (HloInstruction* prod : hoist->operands()) {
463         if (inplace) {
464           map_inst_to_tuple_index[prod] = map_inst_to_tuple_index[hoist];
465           new_operands[map_inst_to_tuple_index[hoist]] = prod;
466           inplace = false;
467         } else {
468           map_inst_to_tuple_index[prod] = new_tuple_count++;
469           new_operands.push_back(prod);
470         }
471       }
472     }
473 
474     // Create the new root instruction.
475     HloComputation* cur_branch = conditional->branch_computation(branch);
476     HloInstruction* new_branch_root =
477         cur_branch->AddInstruction(HloInstruction::CreateTuple(new_operands));
478     // The shape can vary since the operands to convert are now
479     // being returned through the branches' root.
480     cur_branch->set_root_instruction(new_branch_root, true /*new shape*/);
481     TF_CHECK_OK(cur_branch->RemoveInstruction(old_root));
482 
483     // Only one of the branches needs to change the conditional->parent().
484     if (branch != 0) {
485       continue;
486     }
487     HloComputation* conditional_parent = conditional->parent();
488     HloInstruction* newconditional =
489         conditional_parent->AddInstruction(HloInstruction::CreateConditional(
490             cur_branch->root_instruction()->shape(),
491             conditional->mutable_operand(0),
492             absl::MakeSpan(conditional->branch_computations()),
493             absl::MakeSpan(conditional->operands()).subspan(1)));
494     // Ensure that all the users of conditional refer to the new one.
495     TF_RETURN_IF_ERROR(
496         conditional->ReplaceAllUsesWithDifferentShape(newconditional));
497     TF_CHECK_OK(conditional_parent->RemoveInstruction(conditional));
498     conditional = newconditional;
499     // Add the hoisted instructions in the parent.
500     for (HloInstruction* hoist : to_hoist_set) {
501       VLOG(2) << "Hoisting instruction:" << hoist->ToString();
502       int64 hoist_index = map_inst_to_tuple_index[hoist];
503       // Find out the gte that captured the hoisted instr result.
504       HloInstruction* gte_hoist = find_gte(conditional, hoist_index);
505       CHECK(gte_hoist != nullptr);
506       std::vector<HloInstruction*> new_operands;
507       for (HloInstruction* op : hoist->operands()) {
508         HloInstruction* gte = conditional_parent->AddInstruction(
509             HloInstruction::CreateGetTupleElement(op->shape(), conditional,
510                                                   map_inst_to_tuple_index[op]));
511         new_operands.push_back(gte);
512       }
513       HloInstruction* hoisted = conditional_parent->AddInstruction(
514           hoist->CloneWithNewOperands(hoist->shape(), new_operands));
515       VLOG(2) << "Hoisted instruction in parent:" << hoisted->ToString();
516       TF_RETURN_IF_ERROR(gte_hoist->ReplaceAllUsesWith(hoisted));
517       TF_CHECK_OK(conditional_parent->RemoveInstruction(gte_hoist));
518     }
519     // No need to explicitly delete a hoisted instruction since if its dead
520     // then the subsequent DCE will remove it.
521   }
522   VLOG(2) << "AFTER :" << conditional->parent()->parent()->ToString();
523   return true;
524 }
525 
526 // Hoist identical ops out of the conditional. The definition of identical
527 // are the shape of the operands are identical and their properties are
528 // identical. Will start from the root instruction of each branch and get
529 // the identical ops to hoist.
MoveInstructionOut(HloInstruction * conditional,std::vector<Boundary> & to_move_out,std::vector<Boundary> & new_boundaries)530 StatusOr<bool> ConditionalCodeMotion::MoveInstructionOut(
531     HloInstruction* conditional, std::vector<Boundary>& to_move_out,
532     std::vector<Boundary>& new_boundaries) {
533   if (to_move_out.empty()) {
534     return false;
535   }
536   VLOG(1) << "Modifying code--number of boundaries to move out:"
537           << to_move_out.size() << "\n";
538   HloComputation* conditional_parent = conditional->parent();
539   // save the old users before add new conditional user instructions
540   std::vector<HloInstruction*> old_conditional_users = conditional->users();
541   // Maps instructions in the conditional body to instructions hoisted outside
542   // the conditional that compute the same value.
543   absl::flat_hash_map<HloInstruction*, Boundary> hoisted_instructions;
544   // Insert GetTupleElement before the instructions whose operands might still
545   // be within the conditional.
546   VLOG(1) << "before opt:"
547           << conditional_parent->ToString(HloPrintOptions::Fingerprint())
548           << "\n";
549   int64 op_index = 0;
550   for (const Boundary& b : new_boundaries) {
551     HloInstruction* op = b.operands()[0];
552     CHECK(op != nullptr);
553     VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n";
554     HloInstruction* gtr = conditional_parent->AddInstruction(
555         HloInstruction::CreateGetTupleElement(op->shape(), conditional,
556                                               op_index++));
557     Boundary b2(Boundary::Position::kOutsideBranch);
558     b2.mutable_operands().push_back(gtr);
559     hoisted_instructions[op] = b2;
560   }
561   // Copy boundary instructions out of the conditional.
562   // Visit the operands before its users and copy it, so that the copied
563   // user will point to the correct operand.
564   for (int64 i = to_move_out.size() - 1; i >= 0; i--) {
565     TF_RETURN_IF_ERROR(CopyInOrOutOfConditional(
566         to_move_out[i], 0, conditional_parent, hoisted_instructions));
567   }
568   VLOG(2) << "Done copy branch instructions out\n"
569           << conditional_parent->ToString(HloPrintOptions::Fingerprint())
570           << "\n";
571   // Change original users of the conditional to use the correct operands.
572   HloInstruction* old_root =
573       conditional->branch_computation(0)->root_instruction();
574   for (auto user_instr : old_conditional_users) {
575     VLOG(2) << "Checking conditional user: " << user_instr->ToString() << "\n";
576     CHECK(user_instr->opcode() == HloOpcode::kGetTupleElement);
577     auto tuple_opd = static_cast<HloGetTupleElementInstruction*>(user_instr);
578     int64 index = tuple_opd->tuple_index();
579     CHECK(old_root->operands().size() > index);
580     HloInstruction* old_opd = old_root->operands()[index];
581     VLOG(2) << "old opd = " << old_opd << "\n";
582     CHECK(ContainsKey(hoisted_instructions, old_opd));
583     HloInstruction* new_opd = hoisted_instructions[old_opd].operands()[0];
584     CHECK(old_opd != nullptr);
585     CHECK(new_opd != nullptr);
586     VLOG(2) << "Try replace all uses of :" << old_opd->ToString() << "\n";
587     TF_RETURN_IF_ERROR(user_instr->ReplaceAllUsesWith(new_opd));
588     TF_RETURN_IF_ERROR(conditional_parent->RemoveInstruction(user_instr));
589   }
590   VLOG(2) << "Done changing conditional users\n"
591           << conditional_parent->ToString() << "\n";
592   // Create tuple element within each branch and set it as root.
593   int64 branch_count = conditional->branch_count();
594   for (int i = 0; i < branch_count; i++) {
595     auto computation = conditional->branch_computation(i);
596     std::vector<HloInstruction*> elements;
597     for (const auto& b1 : new_boundaries) {
598       HloInstruction* op = b1.operands()[i];
599       CHECK(op != nullptr);
600       VLOG(2) << "Adding to root " << i << " with " << op->ToString() << "\n";
601       elements.push_back(op);
602     }
603     HloInstruction* tuple =
604         computation->AddInstruction(HloInstruction::CreateTuple(elements));
605     computation->set_root_instruction(tuple, true);
606     VLOG(2) << "computation is :" << computation->ToString() << "\n";
607     // Remove hoisted instructions from the branches.
608     for (const auto& b2 : to_move_out) {
609       auto instr_to_remove = b2.operands()[i];
610       // Double check to make sure it is safe to delete the instruction.
611       // Complications may arise due to some operations in the alternative
612       // branches (branches 1..n) being placed into the boundaries multiple
613       // times.
614       if (!computation->IsMarkedAsDead(instr_to_remove) &&
615           instr_to_remove->user_count() == 0) {
616         VLOG(2) << "Removing boundary:" << b2.ToString() << "\n";
617         TF_RETURN_IF_ERROR(computation->RemoveInstruction(instr_to_remove));
618       }
619     }
620   }
621   // Change conditional instruction shape to the shape of the new root.
622   HloInstruction* new_root =
623       conditional->branch_computation(0)->root_instruction();
624   *conditional->mutable_shape() = new_root->shape();
625   VLOG(1) << "done moving instructions out of branches\n"
626           << conditional_parent->ToString(HloPrintOptions::Fingerprint())
627           << "\n";
628   return true;
629 }
630 
631 // Hoist ops from outside of the conditional to inside the branches.
MoveInstructionIn(HloInstruction * conditional,std::vector<Boundary> & to_move_in,std::vector<Boundary> & new_boundaries)632 StatusOr<bool> ConditionalCodeMotion::MoveInstructionIn(
633     HloInstruction* conditional, std::vector<Boundary>& to_move_in,
634     std::vector<Boundary>& new_boundaries) {
635   if (to_move_in.empty()) {
636     return false;
637   }
638   VLOG(1) << "Modifying code---number of boundaries to move in:"
639           << to_move_in.size() << "\n";
640   VLOG(1) << "before opt:"
641           << conditional->parent()->ToString(HloPrintOptions::Fingerprint())
642           << "\n";
643   // Mapping instructions to be moved to their new representations.
644   absl::flat_hash_map<HloInstruction*, Boundary> hoisted_instructions;
645   int64 to_move_in_size = to_move_in.size();
646   int64 branch_count = conditional->branch_count();
647   HloGetTupleElementInstruction* tuple_use =
648       DynCast<HloGetTupleElementInstruction>(to_move_in[0].operands()[0]);
649   // If use_index is -1, the old conditional root entry used by to_move_in
650   // instructions still need to be included as an entry of the modified
651   // conditional root, and the new result of the to_move_in instructions
652   // need to be added as an extra entry of the modified root; otherwise, the
653   // old root entry will be replaced with the new result in the modified root.
654   // The entry replacement should be allowed only if tuple_use has <=1 users.
655   int64 use_index = (tuple_use != nullptr && tuple_use->user_count() == 1)
656                         ? tuple_use->tuple_index()
657                         : -1;
658   VLOG(2) << "Tuple use index = " << use_index << "\n";
659   // Number of old conditional entries still to be used outside.
660   // If conditional shape is not tuple, will create a tuple and use subscript
661   // 0 to save the old operand being used.
662   int64 op_index =
663       conditional->shape().IsTuple()
664           ? ((use_index >= 0) ? conditional->shape().tuple_shapes_size() - 1
665                               : conditional->shape().tuple_shapes_size())
666           : 0;
667   // Use to map the tuple_use instruction to its operand;
668   Boundary b_opd_use(Boundary::Position::kInsideBranch);
669   Boundary b_old_root(Boundary::Position::kInsideBranch);
670   // Create a new root instruction in each branch.
671   for (int i = 0; i < branch_count; i++) {
672     auto computation = conditional->branch_computation(i);
673     auto old_root = computation->root_instruction();
674     b_old_root.mutable_operands().push_back(old_root);
675     std::vector<HloInstruction*> operands;
676     if (old_root->opcode() == HloOpcode::kTuple) {
677       // Use operands of old_root directly, so old_root can be removed later.
678       for (int i = 0; i < old_root->operand_count(); ++i) {
679         if (i != use_index) {
680           operands.push_back(old_root->operands()[i]);
681         } else {  // Map conditional use to the tuple operand.
682           b_opd_use.mutable_operands().push_back(old_root->operands()[i]);
683         }
684       }
685     } else if (old_root->shape().IsTuple()) {
686       // If old_root is not a kTuple but has tuple shape, elements within the
687       // tuple must be extracted first to be used by the new instructions.
688       const Shape& old_shape = old_root->shape();
689       for (int64 i = 0; i < old_shape.tuple_shapes_size(); ++i) {
690         auto element =
691             computation->AddInstruction(HloInstruction::CreateGetTupleElement(
692                 old_shape.tuple_shapes(i), old_root, i));
693         if (i != use_index) {
694           operands.push_back(element);
695         } else {
696           b_opd_use.mutable_operands().push_back(element);
697         }
698       }
699     } else {
700       // If old_root is not a tuple and does not have tuple shape, use it
701       // to replace the conditional directly in the new computation.
702       b_opd_use.mutable_operands().push_back(conditional);
703     }
704 
705     HloInstruction* new_root =
706         computation->AddInstruction(HloInstruction::CreateTuple(operands));
707     VLOG(2) << "setting new root: " << new_root->ToString() << "\n";
708     computation->set_root_instruction(new_root,
709                                       /*accept_different_shape*/ true);
710     if (old_root->opcode() == HloOpcode::kTuple) {
711       TF_RETURN_IF_ERROR(computation->RemoveInstruction(old_root));
712     }
713     VLOG(2) << "new branch computation: " << computation->ToString() << "\n";
714   }
715   // Update get tuple element index of the conditional.
716   if (use_index != -1) {
717     for (auto* user : conditional->users()) {
718       if (user->opcode() == HloOpcode::kGetTupleElement &&
719           user->tuple_index() > use_index) {
720         user->set_tuple_index(user->tuple_index() - 1);
721       }
722     }
723   }
724   hoisted_instructions[conditional] = b_old_root;
725   int64 cp_start = 0;
726   if (use_index >= 0) {
727     VLOG(2) << "Mapping GTE: " << tuple_use->ToString() << "\n";
728     hoisted_instructions[tuple_use] = b_opd_use;
729   }
730   cp_start = (tuple_use != nullptr) ? 1 : 0;
731   for (int64 to_move_index = cp_start; to_move_index < to_move_in_size;
732        to_move_index++) {
733     Boundary b_to_move = to_move_in[to_move_index];
734     HloInstruction* op = b_to_move.operands()[0];
735     CHECK(op != nullptr);
736     bool to_be_used_outside = true;
737     VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n";
738     if (to_move_index < to_move_in_size - 1 && op->user_count() == 1 &&
739         op->users()[0] == to_move_in[to_move_index + 1].operands()[0]) {
740       to_be_used_outside = false;
741       VLOG(2) << "Instruction is not to be used outside the branch\n";
742     }
743     Boundary b(Boundary::Position::kInsideBranch);
744     for (int i = 0; i < branch_count; i++) {
745       auto computation = conditional->branch_computation(i);
746       VLOG(2) << "Copying to branch: " << i << "\n";
747       TF_RETURN_IF_ERROR(CopyInOrOutOfConditional(b_to_move, i, computation,
748                                                   hoisted_instructions));
749       VLOG(2) << "Done:" << computation->ToString() << "\n";
750       if (to_be_used_outside) {
751         auto new_op = hoisted_instructions[op].operands()[i];
752         auto new_root = computation->root_instruction();
753         new_root->AppendOperand(new_op);
754         *new_root->mutable_shape()->add_tuple_shapes() = new_op->shape();
755         VLOG(2) << "Extending conditional root " << i << " : "
756                 << new_root->ToString() << "\n";
757       }
758       VLOG(2) << "After extending branch root: " << computation->ToString()
759               << "\n";
760     }
761     if (to_be_used_outside) {
762       // Modify uses of instructions outside of the conditionals
763       HloInstruction* gtr = conditional->parent()->AddInstruction(
764           HloInstruction::CreateGetTupleElement(op->shape(), conditional,
765                                                 op_index++));
766       TF_RETURN_IF_ERROR(op->ReplaceAllUsesWith(gtr));
767       if (conditional->parent()->root_instruction() == op) {
768         conditional->parent()->set_root_instruction(gtr);
769       }
770     }
771   }
772   VLOG(2) << "Done copying instructions inside branch: "
773           << conditional->ToString(HloPrintOptions::Fingerprint()) << "\n";
774   // Change conditional instruction shape to the shape of the new root.
775   HloInstruction* new_root =
776       conditional->branch_computation(0)->root_instruction();
777   *conditional->mutable_shape() = new_root->shape();
778   VLOG(2) << "Before removing instructions:"
779           << conditional->parent()->ToString() << "\n";
780   // Remove hoisted instructions from the branches.
781   for (int64 i = to_move_in_size - 1; i >= 0; i--) {
782     Boundary boundary_to_move_in = to_move_in[i];
783     HloInstruction* op = boundary_to_move_in.operands()[0];
784     if (op->user_count() == 0) {
785       VLOG(2) << "Removing boundary:" << boundary_to_move_in.ToString() << "\n";
786       TF_RETURN_IF_ERROR(conditional->parent()->RemoveInstruction(op));
787       VLOG(2) << "Done removing boundary.\n";
788     }
789   }
790 
791   // Reset shapes of user gtes to the new shape.
792   if (use_index != -1) {
793     for (auto* user : conditional->users()) {
794       if (user->opcode() == HloOpcode::kGetTupleElement) {
795         VLOG(2) << "Resetting shape of user: " << user->ToString() << "\n";
796         *user->mutable_shape() =
797             conditional->shape().tuple_shapes(user->tuple_index());
798       }
799     }
800   }
801   VLOG(1) << "Done moving instructions inside branches\n"
802           << conditional->parent()->ToString(HloPrintOptions::Fingerprint())
803           << "\n";
804   return true;
805 }
806 
807 // Group single chains of operands or uses of boundaries into new boundaries
808 class GroupConnectedBoundaries {
809  private:
810   std::vector<Boundary> connected_boundaries_, new_boundaries_;
811   HloInstruction* conditional_;
812   HloComputation* conditional_parent_;
813   bool is_layout_sensitive_;
814   // Instructions that have been visited but are not going to be moved.
815   absl::flat_hash_map<HloInstruction*, int>& visited_count_;
816   // The following four lines are configurations of the cost model, which will
817   // be used to determine whether to move an instruction (move_config_) and how
818   // strongly preferred it is to keep a pair of ops together (reuse_config_).
819   // The search_config_ is used to control how to navigate the search space of
820   // the cost model in the context of auto/manual tuning. The flipped array is
821   // used to save which entries in the configuration have been changed in the
822   // search/tuning process.
823   std::vector<std::vector<int64>>& move_config_;
824   std::vector<std::vector<int64>>& reuse_config_;
825   int& search_config_;
826   absl::flat_hash_map<const int64*, int64> flipped_;
827 
828   // The FlipMutation function serves to implement the search of alternative
829   // cost models by deciding whether to flip a given configuration, saved in
830   // the loc parameter. The non_zero parameter provides the new value to use
831   // to flip a zero. The msg parameter is only used for debugging purpposes.
FlipMutation(int64 * loc,const int64 non_zero,const std::string & msg)832   int64 FlipMutation(int64* loc, const int64 non_zero, const std::string& msg) {
833     if (search_config_ == 0 || ContainsKey(flipped_, loc)) {
834       VLOG(2) << "Configured not to search or loc is already flipped.";
835       return *loc;
836     }
837 
838     // The 8-16 digits control the maximum number of times to flip a config.
839     int flip_count = (search_config_ >> 8) & 255;
840     if (flip_count == 0) {
841       VLOG(2) << "Maximum flip count has reached. ";
842       return *loc;
843     }
844 
845     // The last 8 digits control when to start the first flip.
846     int c = search_config_ & 255;
847     VLOG(2) << "flip start index = " << c << "\n";
848     // Only flip the decision if c reaches 0.
849     if (c > 0) {
850       search_config_--;
851       return *loc;
852     }
853 
854     // Decrement flip count so we can stop if it reaches 0.
855     search_config_ -= 256;
856     // Reload the 16-23 digits of the configuration, which controls how
857     // frequently a configuration should be flipped.
858     search_config_ += (search_config_ >> 16) & 255;
859     VLOG(2) << "Updating Flipping configuration = " << search_config_ << "\n";
860 
861     flipped_[loc] = *loc;
862     // Copy the last 8 bits back to the first 8 bits of configuration.
863     switch (*loc) {
864       case 0:
865         *loc = non_zero;
866         break;
867       default:
868         *loc = 0;
869         break;
870     }
871     VLOG(2) << "Flipping decision for: " << msg << ": from " << flipped_[loc]
872             << " to " << *loc << "\n";
873     return *loc;
874   }
875 
876  public:
GroupConnectedBoundaries(HloInstruction * conditional,bool is_layout_sensitive,absl::flat_hash_map<HloInstruction *,int> & visited_count,std::vector<std::vector<int64>> * move_config,std::vector<std::vector<int64>> * reuse_config,int * search_config)877   explicit GroupConnectedBoundaries(
878       HloInstruction* conditional, bool is_layout_sensitive,
879       absl::flat_hash_map<HloInstruction*, int>& visited_count,
880       std::vector<std::vector<int64>>* move_config,
881       std::vector<std::vector<int64>>* reuse_config, int* search_config)
882       : conditional_(conditional),
883         conditional_parent_(conditional->parent()),
884         is_layout_sensitive_(is_layout_sensitive),
885         visited_count_(visited_count),
886         move_config_(*move_config),
887         reuse_config_(*reuse_config),
888         search_config_(*search_config) {}
889   // Returns estimation of potential reuses carried by a given pair of
890   // instructions. Use different integers to classify different levels
891   // of reuses. Assume all instructions can be fused to enable data reuses.
ReusesCarriedBy(HloInstruction * op,HloInstruction * user)892   int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) {
893     std::vector<int64>& curconfig =
894         reuse_config_[static_cast<uint32>(op->opcode())];
895     // Flip the reuse configuration if tuning the cost model.
896     // When flipping, use -10 if flipping to the default reuse model. Other
897     // values can be specified if needed to fine-control the decision making.
898     int64 config =
899         (search_config_ < 0)
900             ? FlipMutation(&curconfig[static_cast<uint32>(user->opcode())], -10,
901                            HloOpcodeString(op->opcode()) + "->" +
902                                HloOpcodeString(user->opcode()))
903             : curconfig[static_cast<uint32>(user->opcode())];
904     VLOG(2) << "ConditionalCodeMotion: Add reuses carried by instr: "
905             << op->ToString() << "=>" << user->ToString() << " : " << config
906             << "\n";
907     if (config < 0) {
908       // Assume the reuse decreases with increasing user count.
909       int count1 = CountNonLeafOps(op->users());
910       int count2 = CountNonLeafOps(user->operands());
911       return (-config) / count1 / count2;
912     }
913     return config;
914   }
clear_recently_visited()915   void clear_recently_visited() {
916     for (const auto& boundary : new_boundaries_) {
917       visited_count_.erase(boundary.operands()[0]);
918     }
919   }
920   // Returns true if `instruction` is worth hoisting.
WorthHoisting(HloInstruction * instruction,bool is_inside_branch)921   bool WorthHoisting(HloInstruction* instruction, bool is_inside_branch) {
922     // This is needed for the "moving-in" transformation, to prevent the root
923     // of the parent computation (which contains the conditional) to be moved
924     // inside the conditional.
925     HloOpcode opcode = instruction->opcode();
926     if (opcode == HloOpcode::kTuple &&
927         instruction == conditional_parent_->root_instruction()) {
928       return false;
929     }
930     // It is not safe to move collective ops from outside to inside
931     // conditional branches, as it may cause synchronization problems,
932     // when different layouts are assigned to different branches.
933     if (opcode == HloOpcode::kAllReduce && !is_inside_branch) {
934       return false;
935     }
936 
937     // It is not legal to move the parameter instructions.
938     if (opcode == HloOpcode::kParameter) {
939       return false;
940     }
941 
942     // Use configuration given from outside (e.g., by autotuner).
943     std::vector<int64>& curconfig = move_config_[static_cast<uint32>(opcode)];
944     auto col = (curconfig.size() == 1) ? 0
945                : (instruction->operand_count() > 0)
946                    ? static_cast<uint32>(instruction->operand(0)->opcode())
947                    : 0;
948     VLOG(2) << "column = " << col << "\n";
949     VLOG(2) << "config size = " << curconfig.size() << "\n";
950     VLOG(2) << "search_config = " << search_config_ << "\n";
951     CHECK(col < curconfig.size());
952     uint32 config = (search_config_ > 0)
953                         ? FlipMutation(&curconfig[col], 1,
954                                        "Move-" + HloOpcodeString(opcode))
955                         : curconfig[col];
956     VLOG(2) << "Checking instruction is worth moving: " << config << "\n";
957     return (config != 0);
958   }
959 
ReusesBeforeBoundary(HloInstruction * user)960   int64 ReusesBeforeBoundary(HloInstruction* user) {
961     int64 reuses = 0;
962     for (auto op : user->operands()) {
963       // The operand must be an instruction that is not going to be moved (if
964       // user is inside the conditional); otherwise it must be the conditional
965       // itself and its user must be outside of the conditional.
966       if (!ContainsKey(visited_count_, op) && op != conditional_) {
967         continue;
968       }
969       if (auto tuple_gte = DynCast<HloGetTupleElementInstruction>(user)) {
970         if (op->opcode() == HloOpcode::kConditional) {
971           auto tuple = op->branch_computation(0)->root_instruction();
972           if (tuple->opcode() == HloOpcode::kTuple) {
973             auto index = tuple_gte->tuple_index();
974             CHECK(index < tuple->operand_count());
975             op = tuple->mutable_operand(index);
976           }
977         }
978         reuses += ReusesCarriedBy(op, user->users()[0]);
979       } else {
980         reuses += ReusesCarriedBy(op, user);
981       }
982     }
983     VLOG(2) << "Reuses before instruction " << user->ToString() << ":" << reuses
984             << "\n";
985     return reuses;
986   }
987 
ReusesAfterBoundary(HloInstruction * user)988   int64 ReusesAfterBoundary(HloInstruction* user) {
989     CHECK(user != nullptr);
990     auto all_users = user->users();
991     // For now, assume that if an instruction has multiple-consumers, it
992     // will not be reused, as the reuse may require duplication in
993     // fusion and so is expensive. If the situation changes in the future,
994     // some aspects of the overall algorithm need to be redesigned to
995     // accommandate the change.
996     if (all_users.size() > 1) {
997       VLOG(2) << "Having multiple users from: " << user->ToString() << "\n";
998       return 0;
999     }
1000     if (!all_users.empty()) {
1001       auto op = all_users[0];
1002       int64 reuses = 0;
1003       // Only count reuses that run through the conditional root.
1004       if (op == conditional_->branch_computation(0)->root_instruction()) {
1005         int64 index = op->operand_index(user);
1006         for (auto op2 : conditional_->users()) {
1007           // If the use is not get tuple, right now do not consider it.
1008           if (op2->opcode() == HloOpcode::kGetTupleElement) {
1009             auto tuple_opd = static_cast<HloGetTupleElementInstruction*>(op2);
1010             if (index == tuple_opd->tuple_index()) {
1011               all_users = op2->users();
1012               if (!all_users.empty()) {
1013                 reuses += ReusesCarriedBy(user, all_users[0]);
1014                 break;
1015               }
1016             }
1017           }
1018         }
1019       } else if (ContainsKey(visited_count_, op)) {
1020         reuses += ReusesCarriedBy(user, op);
1021       }
1022       VLOG(2) << "reuses after instruction " << user->ToString() << ":"
1023               << reuses << "\n";
1024       return reuses;
1025     }
1026     return 0;
1027   }
1028 
BenefitForMovingBoundaries(const std::vector<Boundary> & boundaries)1029   int64 BenefitForMovingBoundaries(const std::vector<Boundary>& boundaries) {
1030     int64 reuses_before = 0, reuses_after = 0;
1031     if (boundaries.size() == 1) {
1032       if (boundaries[0].IsOutsideBranch() &&
1033           boundaries[0].operands()[0]->opcode() ==
1034               HloOpcode::kGetTupleElement) {
1035         // The only boundary of moving-in is the get_tuple_element op.
1036         return -1;
1037       }
1038       if (boundaries[0].IsInsideBranch() &&
1039           boundaries[0].operands()[0]->opcode() == HloOpcode::kTuple) {
1040         // The only boundary of moving-out is the tuple op inside branches.
1041         return -1;
1042       }
1043     }
1044     // If trying alternative moving configurations, turn off reuse analysis.
1045     if (search_config_ > 0) {
1046       return 1;
1047     }
1048     // For cases like :
1049     // branch0 {
1050     //   ROOT copy
1051     // }
1052     // branch1 {
1053     //   ...
1054     // }
1055     // cond = conditional(branch0, branch1)
1056     // copy = copy(cond)
1057     //
1058     // We can fold the two copies thus reducing computation.
1059     auto get_copy_folding_benefit = [&](HloInstruction* hlo) -> int64 {
1060       if (hlo->opcode() != HloOpcode::kCopy) {
1061         return 0;
1062       }
1063       const HloGetTupleElementInstruction* gte =
1064           DynCast<HloGetTupleElementInstruction>(hlo->operand(0));
1065       if (gte == nullptr) {
1066         return 0;
1067       }
1068       const HloInstruction* conditional = gte->operand(0);
1069       if (conditional != conditional_) {
1070         return 0;
1071       }
1072       int64 benefit = 0;
1073       for (auto* branch : conditional->called_computations()) {
1074         HloInstruction* root = branch->root_instruction();
1075         if (root->opcode() == HloOpcode::kTuple) {
1076           const auto* tuple_operand = root->operand(gte->tuple_index());
1077           if (tuple_operand->opcode() == HloOpcode::kCopy) {
1078             if (Shape::Equal()(tuple_operand->operand(0)->shape(),
1079                                hlo->shape())) {
1080               benefit += 10;
1081             }
1082           }
1083         }
1084       }
1085       return benefit;
1086     };
1087     for (const Boundary& b : boundaries) {
1088       auto op = b.operands()[0];
1089       if (op == conditional_->branch_computation(0)->root_instruction()) {
1090         continue;
1091       }
1092       VLOG(2) << "Benefit for " << op->ToString();
1093       reuses_before += ReusesBeforeBoundary(op);
1094       VLOG(2) << "Reuses before boundary so far: " << reuses_before << "\n";
1095       reuses_after += ReusesAfterBoundary(op);
1096       VLOG(2) << "Reuese after boundary so far : " << reuses_after << "\n";
1097     }
1098 
1099     int64 copy_folding_benefit = 0;
1100     if (boundaries[0].IsOutsideBranch()) {
1101       for (const Boundary& b : boundaries) {
1102         auto op = b.operands()[0];
1103         copy_folding_benefit += get_copy_folding_benefit(op);
1104       }
1105     }
1106     VLOG(2) << "Copy folding benefit: " << copy_folding_benefit;
1107 
1108     if (reuses_after == 0 && reuses_before == 0 && copy_folding_benefit == 0) {
1109       return -1;
1110     } else if (boundaries[0].IsInsideBranch()) {
1111       return reuses_after - reuses_before;
1112     } else {
1113       return reuses_before - reuses_after - 1 + copy_folding_benefit;
1114     }
1115   }
1116 
GetNextBoundary(const Boundary & b,int64 op_index)1117   Boundary GetNextBoundary(const Boundary& b, int64 op_index) {
1118     Boundary b2(b.GetPosition());
1119     for (int j = 0; j < b.operands().size(); ++j) {
1120       HloInstruction* inst = b.operands()[j];
1121       CHECK(inst != nullptr);
1122       HloInstruction* op = (b.IsInsideBranch()) ? inst->operands()[op_index]
1123                                                 : inst->users()[op_index];
1124       CHECK(op != nullptr);
1125       b2.mutable_operands().push_back(op);
1126     }
1127     return b2;
1128   }
1129 
1130   // Checking whether it is safe to move a boundary when visited through a
1131   // dependent already considered for moving.
IsSafeToMoveBoundary(const Boundary & next_boundary)1132   bool IsSafeToMoveBoundary(const Boundary& next_boundary) {
1133     int64 next_boundary_count =
1134         (next_boundary.IsInsideBranch())
1135             ? next_boundary.operands()[0]->user_count()
1136             : CountNonLeafOps(next_boundary.operands()[0]->operands());
1137     if (next_boundary_count <= 1) {
1138       // If boundary has only a single or no dependent, safe to move.
1139       return true;
1140     } else {
1141       if (!ContainsKey(visited_count_, next_boundary.operands()[0])) {
1142         VLOG(2) << "Skip next boundary " << next_boundary.ToString() << "\n"
1143                 << " because it has multiple dependents: "
1144                 << next_boundary_count << "\n";
1145         visited_count_[next_boundary.operands()[0]] = 1;
1146         new_boundaries_.push_back(next_boundary);
1147       } else {
1148         auto pos = std::find(new_boundaries_.begin(), new_boundaries_.end(),
1149                              next_boundary);
1150         if (pos != new_boundaries_.end() ||
1151             next_boundary.operands().size() == 1) {
1152           int count = ++visited_count_[next_boundary.operands()[0]];
1153           if (count == next_boundary_count) {
1154             VLOG(2) << "Recovering next boundary " << next_boundary.ToString()
1155                     << "\n"
1156                     << " because all of its dependents have been visited: "
1157                     << next_boundary_count << "\n";
1158             visited_count_.erase(next_boundary.operands()[0]);
1159             if (pos != new_boundaries_.end()) {
1160               new_boundaries_.erase(pos);
1161             }
1162             return true;
1163           }
1164         } else {
1165           VLOG(2) << "Skip incompatible multi-dependent boundary: "
1166                   << next_boundary.ToString() << ":" << next_boundary_count
1167                   << "\n";
1168         }
1169       }
1170     }
1171     return false;
1172   }
1173   // This function is reused both for moving the boundary outside or into a
1174   // conditional. As the result, the readability is somewhat compromised.
1175   // It might be nice to refactor this function to factor the outside-inside
1176   // considerations into separate function pointer parameters to improve
1177   // readability.
AddBoundaries(const Boundary & boundary)1178   void AddBoundaries(const Boundary& boundary) {
1179     BoundaryVisitor visitor;
1180     visitor.AddToWorkList(boundary);
1181     while (visitor.HasNextBoundary()) {
1182       Boundary b = visitor.PopNextBoundary();
1183       VLOG(2) << "visiting boundary " << b.ToString() << "\n";
1184       if ((b.IsOutsideBranch() || InstructionWithinBranchIdentical(
1185                                       b.operands(), is_layout_sensitive_)) &&
1186           IsSafeToMoveBoundary(b) &&
1187           WorthHoisting(b.operands()[0], b.IsInsideBranch())) {
1188         connected_boundaries_.push_back(b);
1189         VLOG(2) << "boundary can be moved\n";
1190         int64 operand_count = (b.IsInsideBranch())
1191                                   ? b.operands()[0]->operand_count()
1192                                   : b.operands()[0]->users().size();
1193         for (int i = 0; i < operand_count; i++) {
1194           Boundary next_boundary = GetNextBoundary(b, i);
1195           VLOG(2) << "Add operand/user " << i << " to visit later\n";
1196           visitor.AddToWorkList(next_boundary);
1197         }
1198       } else {
1199         VLOG(2) << "boundary cannot be moved\n";
1200         visited_count_[b.operands()[0]] = 1;
1201         new_boundaries_.push_back(b);
1202       }
1203     }
1204   }
BoundariesToMoveInOrOut(HloInstruction * conditional,const Boundary & b)1205   std::vector<Boundary> BoundariesToMoveInOrOut(HloInstruction* conditional,
1206                                                 const Boundary& b) {
1207     // At the beginning of optimization, a conditional itself is added to a
1208     // worklist. Here the conditional is expanded into two sets of boundaries:
1209     // the first set contains the boundary that is inside branches and
1210     // contains the root of all branches; the second set of boundaries
1211     // contains all the users of the conditional.
1212     HloInstruction* inst = b.operands()[0];
1213     if (inst == conditional) {
1214       int branch_count = inst->branch_count();
1215       // Add conditional roots as a new boundary to visit.
1216       Boundary boundary_in(Boundary::Position::kInsideBranch);
1217       for (int i = 0; i < branch_count; i++) {
1218         HloComputation* branch_computation = inst->branch_computation(i);
1219         HloInstruction* root_inst = branch_computation->root_instruction();
1220         CHECK(root_inst != nullptr);
1221         boundary_in.mutable_operands().push_back(root_inst);
1222       }
1223       new_boundaries_.push_back(boundary_in);
1224       // Add conditional users as new boundaries to visit.
1225       for (auto u : inst->users()) {
1226         Boundary boundary_in(Boundary::Position::kOutsideBranch);
1227         boundary_in.mutable_operands().push_back(u);
1228         new_boundaries_.push_back(boundary_in);
1229       }
1230     } else {
1231       AddBoundaries(b);
1232     }
1233     return connected_boundaries_;
1234   }
AddNewBoundaries(std::vector<Boundary> & b)1235   void AddNewBoundaries(std::vector<Boundary>& b) {
1236     b.insert(b.end(), new_boundaries_.begin(), new_boundaries_.end());
1237   }
1238 };
1239 
ConsiderCodeMotion(HloInstruction * conditional,const Boundary & cur_boundary,std::vector<Boundary> & to_move,std::vector<Boundary> & new_boundaries,absl::flat_hash_map<HloInstruction *,int> & visited_count)1240 ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion(
1241     HloInstruction* conditional, const Boundary& cur_boundary,
1242     std::vector<Boundary>& to_move, std::vector<Boundary>& new_boundaries,
1243     absl::flat_hash_map<HloInstruction*, int>& visited_count) {
1244   GroupConnectedBoundaries connect(conditional, is_layout_sensitive_,
1245                                    visited_count, &move_config_, &reuse_config_,
1246                                    &search_config_);
1247   auto move_in_or_out =
1248       connect.BoundariesToMoveInOrOut(conditional, cur_boundary);
1249   if (!move_in_or_out.empty()) {
1250     auto benefit = connect.BenefitForMovingBoundaries(move_in_or_out);
1251     VLOG(2) << "benefit of moving in or out "
1252             << cur_boundary.operands()[0]->ToString() << ":" << benefit << "\n";
1253     if (benefit >= 0) {
1254       new_boundaries.clear();
1255       connect.AddNewBoundaries(new_boundaries);
1256       // The whole sequence in move_in_or_out is either all moving into a
1257       // conditional, or all moving out of a conditional. So looking only
1258       // at the first entry of the sequence is sufficient to know which
1259       // direction the move is intended.
1260       to_move = move_in_or_out;
1261       return Decision(to_move[0].IsInsideBranch()
1262                           ? Decision::Direction::kMoveOutOfBranch
1263                           : Decision::Direction::kMoveIntoBranch,
1264                       benefit);
1265     } else {
1266       connect.clear_recently_visited();
1267     }
1268   } else {
1269     connect.AddNewBoundaries(new_boundaries);
1270   }
1271   return Decision(Decision::Direction::kNoChange, 0);
1272 }
1273 
Run(HloModule * module)1274 StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
1275   VLOG(2) << "Begin a new pass of conditional code motion optimization.\n";
1276   // Use to support debugging of optimization, by disabling the opt after it has
1277   // been applied a pre-determined times (to isolate impact of transformations).
1278   if (!ConsumeFuel("conditional_code_motion", [&] {
1279         return "Skipping conditional opt after allowed limit reaching 0.\n";
1280       })) {
1281     return false;
1282   }
1283   bool changed = false;
1284   bool cleanup_changed = false;
1285   {
1286     HloPassPipeline subpipeline("before_conditional_code_motion");
1287     subpipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/is_layout_sensitive_);
1288     subpipeline.AddPass<HloDCE>();
1289     TF_ASSIGN_OR_RETURN(auto cleanup_changed_now, subpipeline.Run(module));
1290     cleanup_changed |= cleanup_changed_now;
1291   }
1292   // set the default configuration
1293   VLOG(2) << "Obtaining default configuration\n";
1294   SetDefaultMoveConfig();
1295   VLOG(2) << "Done obtaining default configuration\n";
1296   // Gather all the conditional ops in the module ahead of time, to avoid
1297   // potential complications of modifying the code that affecting traversal.
1298   std::vector<HloInstruction*> conditional_ops;
1299   // Track how many times each branch computation is shared.
1300   absl::flat_hash_map<HloComputation*, int> conditional_computations;
1301   for (auto* comp : module->MakeComputationPostOrder()) {
1302     for (auto* instr : comp->MakeInstructionPostOrder()) {
1303       if (instr->opcode() == HloOpcode::kConditional) {
1304         int branch_count = instr->branch_count();
1305         for (int i = 0; i < branch_count; ++i) {
1306           HloComputation* branch_i = instr->branch_computation(i);
1307           if (ContainsKey(conditional_computations, branch_i)) {
1308             conditional_computations[branch_i]++;
1309           } else {
1310             conditional_computations[branch_i] = 0;
1311           }
1312         }
1313         if (instr->shape().IsTuple()) {
1314           bool can_change_tuple_shape = true;
1315           for (auto user : instr->users()) {
1316             VLOG(2) << "user is : " << user->ToString() << "\n";
1317             if (user->opcode() != HloOpcode::kGetTupleElement) {
1318               can_change_tuple_shape = false;
1319             }
1320           }
1321           if (can_change_tuple_shape) {
1322             conditional_ops.push_back(instr);
1323           }
1324         } else {
1325           conditional_ops.push_back(instr);
1326         }
1327       }
1328     }
1329   }
1330 
1331   // Use to collect mappings between cloned instructions.
1332   HloCloneContext clone_context(module);
1333   for (HloInstruction* conditional : conditional_ops) {
1334     int branch_count = conditional->branch_count();
1335     // check for shared conditional computations
1336     bool conditional_is_shared = false;
1337     for (int i = 0; i < branch_count; ++i) {
1338       HloComputation* branch_i = conditional->branch_computation(i);
1339       if (conditional_computations[branch_i] > 0) {
1340         conditional_is_shared = true;
1341         break;
1342       }
1343     }
1344 
1345     // Boundaries to move out or to move into the branches.
1346     std::vector<std::vector<Boundary> > to_move_out, to_move_in;
1347     std::vector<std::vector<Boundary> > new_boundaries_for_moveout;
1348     std::vector<std::vector<Boundary> > new_boundaries_for_movein;
1349     // Number of times each instruction has been visited for moving.
1350     absl::flat_hash_map<HloInstruction*, int> visited_count;
1351     int benefit_move_out = 0, benefit_move_in = 0;
1352     Decision::Direction final_d = Decision::Direction::kNoChange;
1353     // The conditional is moved into a worklist as the seed (starting point).
1354     // The conditional will be expanded into multiple seeds (starting points),
1355     // its roots and its users, when it is visited by GroupConnectedBoundaries.
1356     // A NO_CHANGE decision will always be returned for the conditional itself,
1357     // so that the other seeding boundaries can be visited in turn.
1358     BoundaryVisitor visitor(conditional);
1359     VLOG(2) << "Analyzing conditional:" << conditional->ToString() << "\n";
1360     // Try visit all the boundaries, collect the analysis results, and save
1361     // all the benefitical non-conflicting decisions. If two decisions conflict
1362     // with each other, save the more benefitical one.
1363     while (visitor.HasNextBoundary()) {
1364       std::vector<Boundary> to_move, next_boundary;
1365       Boundary boundary = visitor.PopNextBoundary();
1366       VLOG(2) << "Analyzing boundary:" << boundary.ToString() << "\n";
1367       auto d = ConsiderCodeMotion(conditional, boundary, to_move, next_boundary,
1368                                   visited_count);
1369       switch (d.GetDirection()) {
1370         case Decision::Direction::kMoveOutOfBranch:
1371           VLOG(2) << "Local Decision is move out of branch\n";
1372           to_move_out.push_back(to_move);
1373           new_boundaries_for_moveout.push_back(next_boundary);
1374           benefit_move_out += d.GetBenefit();
1375           if (benefit_move_out >= benefit_move_in) {
1376             final_d = Decision::Direction::kMoveOutOfBranch;
1377             VLOG(2) << "Current Decision is move out of branch ("
1378                     << to_move_out.size() << ")\n";
1379           } else {
1380             VLOG(2) << "Current Decision remains move into branch\n";
1381           }
1382           break;
1383         case Decision::Direction::kMoveIntoBranch:
1384           VLOG(2) << "Decision is move into branch\n";
1385           to_move_in.push_back(to_move);
1386           new_boundaries_for_movein.push_back(next_boundary);
1387           benefit_move_in += d.GetBenefit();
1388           if (benefit_move_out >= benefit_move_in) {
1389             VLOG(2) << "Current Decision remains move out of branch\n";
1390           } else {
1391             final_d = Decision::Direction::kMoveIntoBranch;
1392             VLOG(2) << "Current Decision is move into branch ("
1393                     << to_move_in.size() << ")\n";
1394           }
1395           break;
1396         case Decision::Direction::kNoChange:
1397           VLOG(2) << "Decision is no change\n";
1398           for (const Boundary& b : next_boundary) {
1399             visitor.AddToWorkList(b);
1400             VLOG(2) << "Adding new boundary to worklist:" << b.ToString()
1401                     << "\n";
1402           }
1403           break;
1404       }
1405     }
1406     // If modification is to be made, need to clone the shared branches.
1407     if (final_d != Decision::Direction::kNoChange && conditional_is_shared) {
1408       for (int i = 0; i < branch_count; ++i) {
1409         HloComputation* branch_i = conditional->branch_computation(i);
1410         if (conditional_computations[branch_i] > 0) {
1411           // Cloning is absolutely needed if the computation is shared by
1412           // different branches, but the cloning can be potentially avoided
1413           // if the sharing is only among branches of the same conditional.
1414           // If cloning these branches causes a problem due to space issues,
1415           // a fix can pass a vector of unique branches to the actual
1416           // transformations, as an alternative representation of the
1417           // conditional branches to be modified. Right now we assume the
1418           // overhead of cloning is minimal since later stages of the compiler
1419           // inline all the computations anyway.
1420           HloComputation* clone_i =
1421               conditional->parent()->parent()->AddEmbeddedComputation(
1422                   branch_i->Clone("clone", &clone_context));
1423           conditional->set_branch_computation(i, clone_i);
1424           conditional_computations[branch_i]--;
1425           // Need to translate the analysis result to generate correct result.
1426           auto update_boundary = [&](Boundary& boundary) {
1427             auto cloned_instr =
1428                 clone_context.FindInstruction(boundary.operands()[i]);
1429             CHECK(cloned_instr != nullptr);
1430             VLOG(2) << "boundary before cloning:" << boundary.operands()[i]
1431                     << "\n";
1432             boundary.mutable_operands()[i] = cloned_instr;
1433             VLOG(2) << "boundary after cloning:" << boundary.operands()[i]
1434                     << "\n";
1435           };
1436           // Only boundaries to move out need to be updated.
1437           if (final_d == Decision::Direction::kMoveOutOfBranch) {
1438             for (int i = 0; i < to_move_out.size(); ++i) {
1439               std::vector<Boundary>& m = to_move_out[i];
1440               std::for_each(m.begin(), m.end(), update_boundary);
1441             }
1442             for (int i = 0; i < new_boundaries_for_moveout.size(); ++i) {
1443               std::vector<Boundary>& m = new_boundaries_for_moveout[i];
1444               std::for_each(m.begin(), m.end(), update_boundary);
1445             }
1446           }
1447         }
1448       }
1449       VLOG(2) << "Cloned branches as needed: " << conditional->ToString()
1450               << "\n";
1451     }
1452     // At most one of to_move_out or to_move_in can be non-empty, since there is
1453     // only one optimization decision.
1454     if (final_d == Decision::Direction::kMoveOutOfBranch) {
1455       CHECK(to_move_out.size() == new_boundaries_for_moveout.size());
1456       for (int i = 0; i < to_move_out.size(); ++i) {
1457         TF_ASSIGN_OR_RETURN(bool result,
1458                             MoveInstructionOut(conditional, to_move_out[i],
1459                                                new_boundaries_for_moveout[i]));
1460         changed |= result;
1461       }
1462       VLOG(2) << "Done moving out of branches " << to_move_out.size()
1463               << " times. \n";
1464       if (!ConsumeFuel("conditional_code_motion", [&] {
1465             return "Skipping conditional opt after allowed limit reaching 0.\n";
1466           })) {
1467         break;
1468       }
1469     } else if (final_d == Decision::Direction::kMoveIntoBranch) {
1470       CHECK(to_move_in.size() == new_boundaries_for_movein.size());
1471       for (int i = 0; i < to_move_in.size(); ++i) {
1472         TF_ASSIGN_OR_RETURN(bool result,
1473                             MoveInstructionIn(conditional, to_move_in[i],
1474                                               new_boundaries_for_movein[i]));
1475         changed |= result;
1476       }
1477       VLOG(2) << "Done moving into branches " << to_move_in.size()
1478               << " times. \n";
1479       if (!ConsumeFuel("conditional_code_motion", [&] {
1480             return "Skipping conditional opt after allowed limit reaching 0.\n";
1481           })) {
1482         break;
1483       }
1484     } else if (pursue_full_conditional_code_motion_ && !conditional_is_shared) {
1485       // Invoke special handling for convert rematerialization/hoisting
1486       // We need to make sure no sharing is present in the branches because no
1487       // cloning has been done by the earlier analysis.
1488       // TOOD[b/165848866]: extend solution to handle cloning for special move.
1489       TF_ASSIGN_OR_RETURN(
1490           bool convert_result,
1491           ConvertSpecialMove(conditional, is_layout_sensitive_));
1492       if (convert_result) {
1493         VLOG(2) << "Done special moving of convert\n";
1494         if (!ConsumeFuel("conditional_code_motion", [&] {
1495               return "Skipping conditional opt after allowed limit reaching "
1496                      "0.\n";
1497             })) {
1498           break;
1499         }
1500       }
1501       changed |= convert_result;
1502     }
1503   }
1504   if (changed) {
1505     HloPassPipeline subpipeline(
1506         "after_conditional_code_motion_after_convert_hoisting");
1507     VLOG(2) << "starting after motion passes: DCE\n";
1508     subpipeline.AddPass<HloDCE>();
1509     subpipeline.AddPass<TupleSimplifier>();
1510     subpipeline.AddPass<HloDCE>();
1511     TF_ASSIGN_OR_RETURN(auto cleanup_changed_now, subpipeline.Run(module));
1512     cleanup_changed |= cleanup_changed_now;
1513   }
1514   if (cleanup_changed) {
1515     VLOG(2) << "subpipeline cleanup have modified code\n";
1516   }
1517   return changed;
1518 }
1519 
SetDefaultMoveConfig()1520 void ConditionalCodeMotion::SetDefaultMoveConfig() {
1521   int tuning_option = (search_config_ == 0) ? 0 : (search_config_ > 0) ? 1 : 2;
1522 
1523   auto row = HloOpcodeCount();
1524   auto col = row;
1525   VLOG(2) << "Start setting default configuration\n";
1526   reuse_config_.reserve(row);
1527   move_config_.reserve(row);
1528   for (int64 opcode = 0; opcode < row; ++opcode) {
1529     // To save whether an instruction is preferred to be moved.
1530     std::vector<int64> reuse_vec(col, 0);
1531     for (uint32 j = 0; j < col; ++j) {
1532       reuse_vec[j] = ReusesCarriedBy(static_cast<HloOpcode>(opcode),
1533                                      static_cast<HloOpcode>(j));
1534     }
1535     reuse_config_.push_back(reuse_vec);
1536     std::vector<int64> move_vec;
1537     switch (tuning_option) {
1538       case 1:
1539         // Tuning transformation decision --- start with all yes.
1540         // Only a single entry is needed if we don't consider operands of an op
1541         // when searching/tuning transformation decisions.
1542         move_vec.push_back(1);
1543         break;
1544       case 2:  // Tune the ReusesCarriedBy results only.
1545       case 0:
1546         // No tuning --- use the default configuration.
1547         // Use the opcode of first operand to configure default.
1548         move_vec.reserve(col);
1549         for (uint32 j = 0; j < col; ++j) {
1550           move_vec.push_back(WorthHoisting(static_cast<HloOpcode>(opcode),
1551                                            static_cast<HloOpcode>(j)));
1552         }
1553         break;
1554     }
1555     move_config_.push_back(move_vec);
1556   }
1557 }
1558 
1559 }  // namespace conditional_opt
1560 
1561 }  // namespace xla
1562