/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/conditional_code_motion.h"

#include <deque>
#include <iterator>
#include <stack>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"

namespace xla {

namespace conditional_opt {

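// Visits the boundaries of a conditional, starting either from the
// conditional itself or from explicitly added boundaries. Boundaries are
// processed in FIFO order; a boundary whose first operand has already been
// visited is skipped, since a boundary's operands are identical instructions
// from different branches and checking the first operand suffices.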
class BoundaryVisitor {
 public:
  // Start with an existing conditional computation.
  explicit BoundaryVisitor(HloInstruction* conditional) {
    Boundary b(Boundary::Position::kInsideBranch);
    b.mutable_operands().push_back(conditional);
    worklist_.push_back(b);
  }
  // Start with an empty work list.
  BoundaryVisitor() {}
  // Get next boundary to visit.
  Boundary PopNextBoundary() {
    CHECK(!worklist_.empty());
    Boundary b = worklist_.front();
    worklist_.pop_front();
    // If b is already visited, it must have multiple users and is already in
    // new boundaries. Skip it. Only checking the first operand of b because b
    // is expected to have at least one operand, and all the operands in b
    // must be identical instructions from different branches for b to be moved.
    while (!worklist_.empty() && ContainsKey(visited_, b.operands()[0])) {
      b = worklist_.front();
      worklist_.pop_front();
    }
    visited_.insert(b.operands()[0]);
    return b;
  }
  void AddToWorkList(const Boundary& b) {
    CHECK(!b.operands().empty());
    worklist_.push_back(b);
  }

  bool HasNextBoundary() {
    while (!worklist_.empty()) {
      Boundary b = worklist_.front();
      if (!ContainsKey(visited_, b.operands()[0])) {
        break;
      }
      worklist_.pop_front();
    }
    return !worklist_.empty();
  }

 private:
  // worklist is the deque that contains instructions to be visited.
  std::deque<Boundary> worklist_;
  absl::flat_hash_set<HloInstruction*> visited_;
};

template <class OpCollection>
int64 CountNonLeafOps(const OpCollection& ops) {
  absl::flat_hash_set<HloInstruction*> op_set;
  for (auto op : ops) {
    if (!op_set.contains(op) && op->opcode() != HloOpcode::kConstant) {
      op_set.insert(op);
    }
  }
  return op_set.size();
}

// Returns an estimate of the potential reuses carried by a given pair of
// instructions. Different integers classify different levels of reuse.
// This is used as a placeholder only, assuming all instructions can be
// fused to enable data reuse.
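// For example, ReusesCarriedBy(kConditional, kAdd) returns 10 (a strong pull),
// while ReusesCarriedBy(kDot, kConvert) returns 0 (the convert stays outside;
// see the kConvert case below).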
int64 ReusesCarriedBy(HloOpcode op, HloOpcode user) {
  // Reuses in some ways work like forces that pull instructions
  // towards each other. We use a number 0-10 to classify how strong the force
  // is between a pair of operations. Given a group of instructions that can be
  // moved together, if the forces inside a conditional are stronger, the group
  // will be moved inside or remain inside the conditional; otherwise, it will
  // be moved outside to or remain outside of the conditional.
  switch (user) {
    case HloOpcode::kGetTupleElement:
      return 0;
    case HloOpcode::kConvert:
      // Because convert is treated as not movable when it follows a dot or
      // convolution, if op is a dot or convolution here, they must be
      // separated by a conditional boundary. Here we do not try to pull
      // convert inside conditionals to be together with the dot or
      // convolution.
      switch (op) {
        case HloOpcode::kConvolution:
        case HloOpcode::kDot:
          return 0;
        default:
          break;
      }
      break;
    default:
      break;
  }
  switch (op) {
      // These instructions do not carry weight of reuse themselves.
    case HloOpcode::kParameter:
    case HloOpcode::kConstant:
    case HloOpcode::kGetTupleElement:
      return 0;
    case HloOpcode::kConditional:
      return 10;
    default:
      return -10;
  }
}

// Returns true if `op` is worth hoisting.
bool WorthHoisting(HloOpcode op, HloOpcode child_op) {
  // TODO(b/169182921): The following cost model is rather incomplete. It will
  // need to be extended to cover most of the element-wise ops.
  switch (op) {
    case HloOpcode::kConvert:
      // If Convert is after AllReduce, it is worth moving the AllReduce
      // out of the conditional for AR/CRS combining. If Convert is after other
      // ops such as Dot or Convolution, it is better to keep the convert
      // within the conditional so that it can be fused with the Dot or
      // Convolution.
      switch (child_op) {
        case HloOpcode::kAllReduce:
        case HloOpcode::kReshape:
        case HloOpcode::kGetTupleElement:
          return true;
        default:
          return false;
      }
    case HloOpcode::kGetTupleElement:
      switch (child_op) {
        // Do not move GTE if its operand is a parameter.
        case HloOpcode::kParameter:
          return false;
        default:
          return true;
      }
    case HloOpcode::kAllReduce:
    case HloOpcode::kReduceScatter:
    case HloOpcode::kAbs:
    case HloOpcode::kReduce:
    case HloOpcode::kAdd:
    case HloOpcode::kPower:
    case HloOpcode::kCopy:
    case HloOpcode::kConstant:
    case HloOpcode::kSubtract:
    case HloOpcode::kMultiply:
    case HloOpcode::kDivide:
    case HloOpcode::kTuple:
    case HloOpcode::kSqrt:
    case HloOpcode::kRsqrt:
    case HloOpcode::kReshape:
    case HloOpcode::kMinimum:
    case HloOpcode::kMaximum:
      return true;
    default:
      return false;
  }
}

// Compares whether the instructions to be visited in each branch are
// identical.
bool InstructionWithinBranchIdentical(
    const std::vector<HloInstruction*>& instructions,
    bool is_layout_sensitive) {
  // Identical means the shapes of the operands are equal as well.
  auto eq_operand = [&](const HloInstruction* a, const HloInstruction* b) {
    bool eq_operands = is_layout_sensitive
                           ? ShapeUtil::Equal(a->shape(), b->shape())
                           : ShapeUtil::Compatible(a->shape(), b->shape());
    return eq_operands;
  };

  auto eq_computations = [](const HloComputation* a, const HloComputation* b) {
    return *a == *b;
  };

  if (instructions.empty()) {
    return false;
  }

  if (instructions[0]->IsCrossModuleAllReduce()) {
    return std::all_of(
        instructions.begin(), instructions.end(),
        [&](HloInstruction* instruction) {
          if (!instruction->IsCrossModuleAllReduce()) {
            return false;
          }
          auto old_channel_id = instruction->channel_id();
          instruction->set_channel_id(instructions[0]->channel_id());
          bool eq_instructions = instructions[0]->Identical(
              *instruction, eq_operand, eq_computations, is_layout_sensitive);
          instruction->set_channel_id(old_channel_id);
          return eq_instructions;
        });
  }

  return std::all_of(instructions.begin(), instructions.end(),
                     [&](HloInstruction* instruction) {
                       return instructions[0]->Identical(
                           *instruction, eq_operand, eq_computations,
                           is_layout_sensitive);
                     });
}

// Copies the boundary instruction to outside of the conditional, or does the
// opposite (for moving in).
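// `hoisted_instructions` maps each instruction that has already been copied
// across the boundary to the Boundary holding its copies (one copy per branch
// when moving in, a single copy when moving out); it is both consulted for
// operands and extended with the newly created copy here.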
Status CopyInOrOutOfConditional(
    Boundary& boundary, int64_t dest_index, HloComputation* parent,
    absl::flat_hash_map<HloInstruction*, Boundary>& hoisted_instructions) {
  CHECK(dest_index == 0 || boundary.IsOutsideBranch());
  HloInstruction* op = boundary.operands()[0];
  absl::InlinedVector<HloInstruction*, 4> new_operands;
  for (int i = 0; i < op->operands().size(); ++i) {
    auto op_i = op->operands()[i];
    VLOG(2) << "Looking for " << op_i->ToString() << "\n";
    if (ContainsKey(hoisted_instructions, op_i)) {
      auto new_op_i =
          FindOrDie(hoisted_instructions, op_i).operands()[dest_index];
      VLOG(2) << "new instruction:" << new_op_i->ToString() << "\n";
      new_operands.push_back(new_op_i);
    } else {
      switch (op_i->opcode()) {
        case HloOpcode::kConstant: {
          auto new_op_i = parent->AddInstruction(op_i->Clone());
          VLOG(2) << "new instruction:" << new_op_i->ToString() << "\n";
          new_operands.push_back(new_op_i);
          break;
        }
        case HloOpcode::kGetTupleElement: {
          auto gte = Cast<HloGetTupleElementInstruction>(op_i);
          int64_t index = gte->tuple_index();
          HloInstruction* root = parent->root_instruction();
          CHECK(root->opcode() == HloOpcode::kTuple &&
                index < root->operand_count());
          auto new_op_i = root->mutable_operand(index);
          VLOG(2) << "new instruction:" << new_op_i->ToString() << "\n";
          new_operands.push_back(new_op_i);
          break;
        }
        default:
          LOG(FATAL) << "Unexpected out-of-boundary instruction:"
                     << op_i->ToString() << "\n";
      }
    }
  }
  HloInstruction* new_instruction = parent->AddInstruction(
      op->CloneWithNewOperands(op->shape(), new_operands));
  VLOG(2) << "new instruction:" << new_instruction->ToString() << "\n";
  // Map each original instruction to its new copy on the other side of the
  // boundary.
  for (HloInstruction* op : boundary.operands()) {
    Boundary b2 = ContainsKey(hoisted_instructions, op)
                      ? hoisted_instructions[op]
                      : Boundary(boundary.IsOutsideBranch()
                                     ? Boundary::Position::kInsideBranch
                                     : Boundary::Position::kOutsideBranch);
    b2.mutable_operands().push_back(new_instruction);
    hoisted_instructions[op] = b2;
  }
  return Status::OK();
}

// Identify converts to be hoisted/rematerialized out of the branch
// computations.
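// A convert is "special" if every branch root has a convert with an equal
// (or, when layout-insensitive, compatible) shape at the same operand
// position, so the convert can be rematerialized outside the conditional.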
absl::flat_hash_set<int64> FindSpecialConverts(HloInstruction* old_root,
                                               int branch_count,
                                               HloInstruction* conditional,
                                               bool is_layout_sensitive) {
  absl::flat_hash_set<int64> kspecial_convert;
  for (int64_t operand_num = 0; operand_num < old_root->operand_count();
       ++operand_num) {
    if (old_root->operand(operand_num)->opcode() != HloOpcode::kConvert) {
      continue;
    }
    bool replica = true;
    HloInstruction* kspecial_convert_candidate =
        old_root->mutable_operand(operand_num);
    // Check whether an identical candidate appears in the other branches.
    for (int others = 1; others < branch_count; ++others) {
      HloInstruction* others_root =
          conditional->branch_computation(others)->root_instruction();
      bool eq_shape =
          is_layout_sensitive
              ? ShapeUtil::Equal(others_root->operand(operand_num)->shape(),
                                 kspecial_convert_candidate->shape())
              : ShapeUtil::Compatible(
                    others_root->operand(operand_num)->shape(),
                    kspecial_convert_candidate->shape());
      if ((others_root->operand(operand_num)->opcode() ==
           HloOpcode::kConvert) &&
          eq_shape) {
        // Nothing to be done.
      } else {
        replica = false;
        break;
      }
    }
    if (replica) {
      kspecial_convert.insert(operand_num);
    }
  }
  return kspecial_convert;
}

// Restructures the conditional instruction as follows:
// %result = conditional() becomes
// x = conditional()
// y.{0..n} = gte(x, {0..n})
// z = tuple(y.0, y.1, ...y.n)
// Doing so ensures that we can accommodate the possible shape change of the
// conditional when the instructions are hoisted.
Status RestructureConditionalInstruction(HloComputation* computation,
                                         HloInstruction* conditional) {
  HloInstruction* old_root = computation->root_instruction();
  std::vector<HloInstruction*> new_operands;
  int cur_index = 0;
  for (; cur_index < ShapeUtil::TupleElementCount(conditional->shape());
       ++cur_index) {
    new_operands.push_back(
        computation->AddInstruction(HloInstruction::CreateGetTupleElement(
            ShapeUtil::GetTupleElementShape(conditional->shape(), cur_index),
            conditional, cur_index)));
  }
  HloInstruction* new_tuple =
      computation->AddInstruction(HloInstruction::CreateTuple(new_operands));
  if (old_root == conditional) {
    computation->set_root_instruction(new_tuple);
  } else {
    std::vector<HloInstruction*> new_tuple_users;
    for (auto conditional_user : conditional->users()) {
      auto is_new_gte = absl::c_find_if(
          new_operands,
          [&](HloInstruction* instr) { return instr == conditional_user; });
      if (is_new_gte == new_operands.end()) {
        new_tuple_users.push_back(conditional_user);
      }
    }
    for (auto new_tuple_user : new_tuple_users) {
      TF_RETURN_IF_ERROR(
          conditional->ReplaceUseWith(new_tuple_user, new_tuple));
    }
  }
  VLOG(2) << "computation after root restructure:\n" << computation->ToString();
  return Status::OK();
}

StatusOr<bool> ConvertSpecialMove(HloInstruction* conditional,
                                  bool is_layout_sensitive) {
  int branch_count = conditional->branch_count();
  if (branch_count <= 0) {
    return false;
  }

  // Determine whether all branch roots are tuples.
  for (int branch_num = 0; branch_num < branch_count; ++branch_num) {
    HloInstruction* branch_root =
        conditional->branch_computation(branch_num)->root_instruction();
    if (branch_root->opcode() != HloOpcode::kTuple) {
      return false;
    }
  }

  HloInstruction* old_root =
      conditional->branch_computation(0)->root_instruction();
  VLOG(2) << "BEFORE :" << conditional->parent()->parent()->ToString();
  // Identify the gte using `index`.
  auto find_gte = [](const HloInstruction* conditional_result,
                     int64_t index) -> HloInstruction* {
    for (HloInstruction* instr : conditional_result->users()) {
      if (instr->opcode() != HloOpcode::kGetTupleElement) {
        return nullptr;
      }
      if (instr->tuple_index() == index) {
        return instr;
      }
    }
    return nullptr;
  };

  // Captures tuple indices referring to converts to be rematerialized/hoisted.
  absl::flat_hash_set<int64> kspecial_convert = FindSpecialConverts(
      old_root, branch_count, conditional, is_layout_sensitive);

  // Exit if we cannot find any converts to be hoisted.
  if (kspecial_convert.empty()) {
    return false;
  }

  TF_RETURN_IF_ERROR(
      RestructureConditionalInstruction(conditional->parent(), conditional));

  for (int branch = 0; branch < branch_count; branch++) {
    old_root = conditional->branch_computation(branch)->root_instruction();
    absl::flat_hash_map<HloInstruction*, int64> map_inst_to_tuple_index;
    std::vector<HloInstruction*> new_operands(old_root->operand_count());
    absl::flat_hash_set<HloInstruction*> to_hoist_set;

    for (int64_t operand_num = 0; operand_num < old_root->operand_count();
         ++operand_num) {
      map_inst_to_tuple_index[old_root->mutable_operand(operand_num)] =
          operand_num;
    }
    for (int64_t operand_num = 0; operand_num < old_root->operand_count();
         ++operand_num) {
      HloInstruction* hoist = old_root->mutable_operand(operand_num);
      if (!kspecial_convert.contains(operand_num)) {
        new_operands[operand_num] = old_root->mutable_operand(operand_num);
        continue;
      }

      to_hoist_set.insert(hoist);
      int64_t new_tuple_count = old_root->operand_count();

      // Replace the hoisted instr in the tuple with its operand/operands.
      // We will replace at least one of the operands of the hoist at the
      // tuple place; the rest will be added at the end.
      bool inplace = true;
      CHECK(!hoist->operands().empty());
      for (HloInstruction* prod : hoist->operands()) {
        if (inplace) {
          map_inst_to_tuple_index[prod] = map_inst_to_tuple_index[hoist];
          new_operands[map_inst_to_tuple_index[hoist]] = prod;
          inplace = false;
        } else {
          map_inst_to_tuple_index[prod] = new_tuple_count++;
          new_operands.push_back(prod);
        }
      }
    }

    // Create the new root instruction.
    HloComputation* cur_branch = conditional->branch_computation(branch);
    HloInstruction* new_branch_root =
        cur_branch->AddInstruction(HloInstruction::CreateTuple(new_operands));
    // The shape can vary since the operands to convert are now
    // being returned through the branches' root.
    cur_branch->set_root_instruction(new_branch_root, true /*new shape*/);
    TF_CHECK_OK(cur_branch->RemoveInstruction(old_root));

    // Only one of the branches needs to change the conditional->parent().
    if (branch != 0) {
      continue;
    }
    HloComputation* conditional_parent = conditional->parent();
    HloInstruction* newconditional =
        conditional_parent->AddInstruction(HloInstruction::CreateConditional(
            cur_branch->root_instruction()->shape(),
            conditional->mutable_operand(0),
            absl::MakeSpan(conditional->branch_computations()),
            absl::MakeSpan(conditional->operands()).subspan(1)));
    // Ensure that all the users of conditional refer to the new one.
    TF_RETURN_IF_ERROR(
        conditional->ReplaceAllUsesWithDifferentShape(newconditional));
    TF_CHECK_OK(conditional_parent->RemoveInstruction(conditional));
    conditional = newconditional;
    // Add the hoisted instructions to the parent.
    for (HloInstruction* hoist : to_hoist_set) {
      VLOG(2) << "Hoisting instruction:" << hoist->ToString();
      int64_t hoist_index = map_inst_to_tuple_index[hoist];
      // Find out the gte that captured the hoisted instr result.
      HloInstruction* gte_hoist = find_gte(conditional, hoist_index);
      CHECK(gte_hoist != nullptr);
      std::vector<HloInstruction*> new_operands;
      for (HloInstruction* op : hoist->operands()) {
        HloInstruction* gte = conditional_parent->AddInstruction(
            HloInstruction::CreateGetTupleElement(op->shape(), conditional,
                                                  map_inst_to_tuple_index[op]));
        new_operands.push_back(gte);
      }
      HloInstruction* hoisted = conditional_parent->AddInstruction(
          hoist->CloneWithNewOperands(hoist->shape(), new_operands));
      VLOG(2) << "Hoisted instruction in parent:" << hoisted->ToString();
      TF_RETURN_IF_ERROR(gte_hoist->ReplaceAllUsesWith(hoisted));
      TF_CHECK_OK(conditional_parent->RemoveInstruction(gte_hoist));
    }
    // No need to explicitly delete a hoisted instruction since if it is dead
    // the subsequent DCE will remove it.
  }
  VLOG(2) << "AFTER :" << conditional->parent()->parent()->ToString();
  return true;
}

// Hoist identical ops out of the conditional. Ops are considered identical
// if the shapes of their operands and their properties are identical.
// Starts from the root instruction of each branch and collects the identical
// ops to hoist.
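// The transformation proceeds in four steps: (1) add get-tuple-elements on
// the conditional for every new boundary instruction; (2) clone the
// to-move-out boundary instructions into the parent computation; (3) redirect
// the old users of the conditional to the hoisted copies; (4) rebuild each
// branch root as a tuple of the remaining boundary instructions and shrink
// the conditional's shape accordingly.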
StatusOr<bool> ConditionalCodeMotion::MoveInstructionOut(
    HloInstruction* conditional, std::vector<Boundary>& to_move_out,
    std::vector<Boundary>& new_boundaries) {
  if (to_move_out.empty()) {
    return false;
  }
  VLOG(1) << "Modifying code--number of boundaries to move out:"
          << to_move_out.size() << "\n";
  HloComputation* conditional_parent = conditional->parent();
  // Save the old users before adding new conditional user instructions.
  std::vector<HloInstruction*> old_conditional_users = conditional->users();
  // Maps instructions in the conditional body to instructions hoisted outside
  // the conditional that compute the same value.
  absl::flat_hash_map<HloInstruction*, Boundary> hoisted_instructions;
  // Insert GetTupleElement before the instructions whose operands might still
  // be within the conditional.
  VLOG(1) << "before opt:"
          << conditional_parent->ToString(HloPrintOptions::Fingerprint())
          << "\n";
  int64_t op_index = 0;
  for (const Boundary& b : new_boundaries) {
    HloInstruction* op = b.operands()[0];
    CHECK(op != nullptr);
    VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n";
    HloInstruction* gtr = conditional_parent->AddInstruction(
        HloInstruction::CreateGetTupleElement(op->shape(), conditional,
                                              op_index++));
    Boundary b2(Boundary::Position::kOutsideBranch);
    b2.mutable_operands().push_back(gtr);
    hoisted_instructions[op] = b2;
  }
  // Copy boundary instructions out of the conditional.
  // Visit the operands before their users and copy them, so that the copied
  // users will point to the correct operands.
  for (int64_t i = to_move_out.size() - 1; i >= 0; i--) {
    TF_RETURN_IF_ERROR(CopyInOrOutOfConditional(
        to_move_out[i], 0, conditional_parent, hoisted_instructions));
  }
  VLOG(2) << "Done copying branch instructions out\n"
          << conditional_parent->ToString(HloPrintOptions::Fingerprint())
          << "\n";
  // Change original users of the conditional to use the correct operands.
  HloInstruction* old_root =
      conditional->branch_computation(0)->root_instruction();
  for (auto user_instr : old_conditional_users) {
    VLOG(2) << "Checking conditional user: " << user_instr->ToString() << "\n";
    CHECK(user_instr->opcode() == HloOpcode::kGetTupleElement);
    auto tuple_opd = static_cast<HloGetTupleElementInstruction*>(user_instr);
    int64_t index = tuple_opd->tuple_index();
    CHECK(old_root->operands().size() > index);
    HloInstruction* old_opd = old_root->operands()[index];
    VLOG(2) << "old opd = " << old_opd << "\n";
    CHECK(ContainsKey(hoisted_instructions, old_opd));
    HloInstruction* new_opd = hoisted_instructions[old_opd].operands()[0];
    CHECK(old_opd != nullptr);
    CHECK(new_opd != nullptr);
    VLOG(2) << "Try replace all uses of :" << old_opd->ToString() << "\n";
    TF_RETURN_IF_ERROR(user_instr->ReplaceAllUsesWith(new_opd));
    TF_RETURN_IF_ERROR(conditional_parent->RemoveInstruction(user_instr));
  }
  VLOG(2) << "Done changing conditional users\n"
          << conditional_parent->ToString() << "\n";
  // Create a tuple of elements within each branch and set it as root.
  int64_t branch_count = conditional->branch_count();
  for (int i = 0; i < branch_count; i++) {
    auto computation = conditional->branch_computation(i);
    std::vector<HloInstruction*> elements;
    for (const auto& b1 : new_boundaries) {
      HloInstruction* op = b1.operands()[i];
      CHECK(op != nullptr);
      VLOG(2) << "Adding to root " << i << " with " << op->ToString() << "\n";
      elements.push_back(op);
    }
    HloInstruction* tuple =
        computation->AddInstruction(HloInstruction::CreateTuple(elements));
    computation->set_root_instruction(tuple, true);
    VLOG(2) << "computation is :" << computation->ToString() << "\n";
    // Remove hoisted instructions from the branches.
    for (const auto& b2 : to_move_out) {
      auto instr_to_remove = b2.operands()[i];
      // Double-check to make sure it is safe to delete the instruction.
      // Complications may arise due to some operations in the alternative
      // branches (branches 1..n) being placed into the boundaries multiple
      // times.
      if (!computation->IsMarkedAsDead(instr_to_remove) &&
          instr_to_remove->user_count() == 0) {
        VLOG(2) << "Removing boundary:" << b2.ToString() << "\n";
        TF_RETURN_IF_ERROR(computation->RemoveInstruction(instr_to_remove));
      }
    }
  }
  // Change conditional instruction shape to the shape of the new root.
  HloInstruction* new_root =
      conditional->branch_computation(0)->root_instruction();
  *conditional->mutable_shape() = new_root->shape();
  VLOG(1) << "done moving instructions out of branches\n"
          << conditional_parent->ToString(HloPrintOptions::Fingerprint())
          << "\n";
  return true;
}

// Hoist ops from outside of the conditional to inside the branches.
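// To do so, each branch root is rewritten as a tuple that drops the entry
// consumed by the moved instructions (when that entry has a single user), the
// instructions are cloned into every branch, and their results are appended
// to the branch roots and read back outside through new get-tuple-elements.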
StatusOr<bool> ConditionalCodeMotion::MoveInstructionIn(
    HloInstruction* conditional, std::vector<Boundary>& to_move_in,
    std::vector<Boundary>& new_boundaries) {
  if (to_move_in.empty()) {
    return false;
  }
  VLOG(1) << "Modifying code---number of boundaries to move in:"
          << to_move_in.size() << "\n";
  VLOG(1) << "before opt:"
          << conditional->parent()->ToString(HloPrintOptions::Fingerprint())
          << "\n";
  // Mapping instructions to be moved to their new representations.
  absl::flat_hash_map<HloInstruction*, Boundary> hoisted_instructions;
  int64_t to_move_in_size = to_move_in.size();
  int64_t branch_count = conditional->branch_count();
  HloGetTupleElementInstruction* tuple_use =
      DynCast<HloGetTupleElementInstruction>(to_move_in[0].operands()[0]);
  // If use_index is -1, the old conditional root entry used by the to_move_in
  // instructions still needs to be included as an entry of the modified
  // conditional root, and the new result of the to_move_in instructions
  // needs to be added as an extra entry of the modified root; otherwise, the
  // old root entry will be replaced with the new result in the modified root.
  // The entry replacement should be allowed only if tuple_use has <=1 users.
  int64_t use_index = (tuple_use != nullptr && tuple_use->user_count() == 1)
                          ? tuple_use->tuple_index()
                          : -1;
  VLOG(2) << "Tuple use index = " << use_index << "\n";
  // Number of old conditional entries still to be used outside.
  // If the conditional shape is not a tuple, a tuple will be created and
  // subscript 0 used to save the old operand being used.
  int64_t op_index =
      conditional->shape().IsTuple()
          ? ((use_index >= 0) ? conditional->shape().tuple_shapes_size() - 1
                              : conditional->shape().tuple_shapes_size())
          : 0;
  // Used to map the tuple_use instruction to its operand.
  Boundary b_opd_use(Boundary::Position::kInsideBranch);
  Boundary b_old_root(Boundary::Position::kInsideBranch);
  // Create a new root instruction in each branch.
  for (int i = 0; i < branch_count; i++) {
    auto computation = conditional->branch_computation(i);
    auto old_root = computation->root_instruction();
    b_old_root.mutable_operands().push_back(old_root);
    std::vector<HloInstruction*> operands;
    if (old_root->opcode() == HloOpcode::kTuple) {
      // Use operands of old_root directly, so old_root can be removed later.
      for (int i = 0; i < old_root->operand_count(); ++i) {
        if (i != use_index) {
          operands.push_back(old_root->operands()[i]);
        } else {  // Map conditional use to the tuple operand.
          b_opd_use.mutable_operands().push_back(old_root->operands()[i]);
        }
      }
    } else if (old_root->shape().IsTuple()) {
      // If old_root is not a kTuple but has a tuple shape, elements within the
      // tuple must be extracted first to be used by the new instructions.
      const Shape& old_shape = old_root->shape();
      for (int64_t i = 0; i < old_shape.tuple_shapes_size(); ++i) {
        auto element =
            computation->AddInstruction(HloInstruction::CreateGetTupleElement(
                old_shape.tuple_shapes(i), old_root, i));
        if (i != use_index) {
          operands.push_back(element);
        } else {
          b_opd_use.mutable_operands().push_back(element);
        }
      }
    } else {
      // If old_root is not a tuple and does not have a tuple shape, use it
      // to replace the conditional directly in the new computation.
      b_opd_use.mutable_operands().push_back(conditional);
    }

    HloInstruction* new_root =
        computation->AddInstruction(HloInstruction::CreateTuple(operands));
    VLOG(2) << "setting new root: " << new_root->ToString() << "\n";
    computation->set_root_instruction(new_root,
                                      /*accept_different_shape*/ true);
    if (old_root->opcode() == HloOpcode::kTuple) {
      TF_RETURN_IF_ERROR(computation->RemoveInstruction(old_root));
    }
    VLOG(2) << "new branch computation: " << computation->ToString() << "\n";
  }
  // Update the get-tuple-element indices of the conditional's users.
  if (use_index != -1) {
    for (auto* user : conditional->users()) {
      if (user->opcode() == HloOpcode::kGetTupleElement &&
          user->tuple_index() > use_index) {
        user->set_tuple_index(user->tuple_index() - 1);
      }
    }
  }
  hoisted_instructions[conditional] = b_old_root;
  int64_t cp_start = 0;
  if (use_index >= 0) {
    VLOG(2) << "Mapping GTE: " << tuple_use->ToString() << "\n";
    hoisted_instructions[tuple_use] = b_opd_use;
  }
  cp_start = (tuple_use != nullptr) ? 1 : 0;
  for (int64_t to_move_index = cp_start; to_move_index < to_move_in_size;
       to_move_index++) {
    Boundary b_to_move = to_move_in[to_move_index];
    HloInstruction* op = b_to_move.operands()[0];
    CHECK(op != nullptr);
    bool to_be_used_outside = true;
    VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n";
    if (to_move_index < to_move_in_size - 1 && op->user_count() == 1 &&
        op->users()[0] == to_move_in[to_move_index + 1].operands()[0]) {
      to_be_used_outside = false;
      VLOG(2) << "Instruction is not to be used outside the branch\n";
    }
    Boundary b(Boundary::Position::kInsideBranch);
    for (int i = 0; i < branch_count; i++) {
      auto computation = conditional->branch_computation(i);
      VLOG(2) << "Copying to branch: " << i << "\n";
      TF_RETURN_IF_ERROR(CopyInOrOutOfConditional(b_to_move, i, computation,
                                                  hoisted_instructions));
      VLOG(2) << "Done:" << computation->ToString() << "\n";
      if (to_be_used_outside) {
        auto new_op = hoisted_instructions[op].operands()[i];
        auto new_root = computation->root_instruction();
        new_root->AppendOperand(new_op);
        *new_root->mutable_shape()->add_tuple_shapes() = new_op->shape();
        VLOG(2) << "Extending conditional root " << i << " : "
                << new_root->ToString() << "\n";
      }
      VLOG(2) << "After extending branch root: " << computation->ToString()
              << "\n";
    }
    if (to_be_used_outside) {
      // Modify uses of instructions outside of the conditional.
      HloInstruction* gtr = conditional->parent()->AddInstruction(
          HloInstruction::CreateGetTupleElement(op->shape(), conditional,
                                                op_index++));
      TF_RETURN_IF_ERROR(op->ReplaceAllUsesWith(gtr));
      if (conditional->parent()->root_instruction() == op) {
        conditional->parent()->set_root_instruction(gtr);
      }
    }
  }
  VLOG(2) << "Done copying instructions inside branch: "
          << conditional->ToString(HloPrintOptions::Fingerprint()) << "\n";
  // Change conditional instruction shape to the shape of the new root.
  HloInstruction* new_root =
      conditional->branch_computation(0)->root_instruction();
  *conditional->mutable_shape() = new_root->shape();
  VLOG(2) << "Before removing instructions:"
          << conditional->parent()->ToString() << "\n";
  // Remove hoisted instructions from the branches.
  for (int64_t i = to_move_in_size - 1; i >= 0; i--) {
    Boundary boundary_to_move_in = to_move_in[i];
    HloInstruction* op = boundary_to_move_in.operands()[0];
    if (op->user_count() == 0) {
      VLOG(2) << "Removing boundary:" << boundary_to_move_in.ToString() << "\n";
      TF_RETURN_IF_ERROR(conditional->parent()->RemoveInstruction(op));
      VLOG(2) << "Done removing boundary.\n";
    }
  }

  // Reset shapes of user gtes to the new shape.
  if (use_index != -1) {
    for (auto* user : conditional->users()) {
      if (user->opcode() == HloOpcode::kGetTupleElement) {
        VLOG(2) << "Resetting shape of user: " << user->ToString() << "\n";
        *user->mutable_shape() =
            conditional->shape().tuple_shapes(user->tuple_index());
      }
    }
  }
  VLOG(1) << "Done moving instructions inside branches\n"
          << conditional->parent()->ToString(HloPrintOptions::Fingerprint())
          << "\n";
  return true;
}

// Group single chains of operands or uses of boundaries into new boundaries.
class GroupConnectedBoundaries {
 private:
  std::vector<Boundary> connected_boundaries_, new_boundaries_;
  HloInstruction* conditional_;
  HloComputation* conditional_parent_;
  bool is_layout_sensitive_;
  // Instructions that have been visited but are not going to be moved.
  absl::flat_hash_map<HloInstruction*, int>& visited_count_;
  // The following four members are configurations of the cost model, which
  // will be used to determine whether to move an instruction (move_config_)
  // and how strongly preferred it is to keep a pair of ops together
  // (reuse_config_). The search_config_ is used to control how to navigate
  // the search space of the cost model in the context of auto/manual tuning.
  // The flipped_ map is used to record which entries in the configuration
  // have been changed in the search/tuning process.
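  // As described in FlipMutation below, each search_config_ value packs three
  // decimal fields: the low digits delay the first flip, the middle digits
  // bound how many flips may occur, and the high digits give the stride
  // between consecutive flips.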
  std::vector<std::vector<int64>>& move_config_;
  std::vector<std::vector<int64>>& reuse_config_;
  std::vector<int64>& search_config_vec_;
  int64* search_config_;
  int64 search_subscript_;
  absl::flat_hash_map<const int64*, int64> flipped_;

  // The FlipMutation function implements the search for alternative
  // cost models by deciding whether to flip a given configuration, saved in
  // the loc parameter. The non_zero parameter provides the new value to use
  // to flip a zero. The msg parameter is only used for debugging purposes.
  int64 FlipMutation(int64* loc, const int64_t non_zero,
                     const std::string& msg) {
    if (search_config_ == 0 || ContainsKey(flipped_, loc)) {
      VLOG(2) << "Configured not to search or loc is already flipped.";
      return *loc;
    }
    // The last 8 digits control when to start the first flip.
    int c = ConditionalCodeMotion::flip_start(*search_config_);
    VLOG(2) << "flip start index = " << c << "\n";
    // Only flip the decision if c reaches 0.
    if (c > 0) {
      (*search_config_)--;
      return *loc;
    }
    // Digits 8-16 control the maximum number of times to flip a config.
    auto flip_count = ConditionalCodeMotion::DecrementMaxFlip(search_config_);
    VLOG(2) << "max flip count = " << flip_count << "\n";
    VLOG(2) << "Updating max flipping configuration = " << *search_config_
            << "\n";
    if (flip_count == 0) {
      VLOG(2) << "Maximum flip count has been reached. ";
      if (search_subscript_ + 1 < search_config_vec_.size()) {
        VLOG(2) << "search_subscript_ = " << search_subscript_;
        VLOG(2) << "search config vec size = " << search_config_vec_.size();
        search_config_ = &search_config_vec_[++search_subscript_];
      } else {
        return *loc;
      }
    }
    // Reload digits 16-23 of the configuration, which control how
    // frequently a configuration should be flipped.
    auto flip_stride = ConditionalCodeMotion::flip_stride(*search_config_);
    *search_config_ += flip_stride;
    VLOG(2) << "flip stride = " << flip_stride << "\n";
    VLOG(2) << "Updating flipping stride = " << *search_config_ << "\n";

    flipped_[loc] = *loc;
    // Flip the configuration entry: a zero becomes non_zero, and any
    // non-zero entry becomes zero.
    switch (*loc) {
      case 0:
        *loc = non_zero;
        break;
      default:
        *loc = 0;
        break;
    }
    VLOG(2) << "Flipping decision for: " << msg << ": from " << flipped_[loc]
            << " to " << *loc << "\n";
    return *loc;
  }

 public:
  explicit GroupConnectedBoundaries(
      HloInstruction* conditional, bool is_layout_sensitive,
      absl::flat_hash_map<HloInstruction*, int>& visited_count,
      std::vector<std::vector<int64>>* move_config,
      std::vector<std::vector<int64>>* reuse_config,
      std::vector<int64>* search_config)
      : conditional_(conditional),
        conditional_parent_(conditional->parent()),
        is_layout_sensitive_(is_layout_sensitive),
        visited_count_(visited_count),
        move_config_(*move_config),
        reuse_config_(*reuse_config),
        search_config_vec_(*search_config),
        search_subscript_(0) {
    VLOG(2) << "Initializing Group Connected Boundaries\n";
    CHECK_NE(search_config, nullptr);
    if (search_config_vec_.empty()) {
      search_config_vec_.push_back(0);
    }
    search_config_ = &search_config_vec_[0];
  }
  // Returns an estimate of the potential reuses carried by a given pair of
  // instructions. Different integers classify different levels of reuse.
  // Assumes all instructions can be fused to enable data reuse.
  int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) {
    std::vector<int64>& curconfig =
        reuse_config_[static_cast<uint32>(op->opcode())];
    // Flip the reuse configuration if tuning the cost model.
    // When flipping, use -10 if flipping to the default reuse model. Other
    // values can be specified if needed to fine-control the decision making.
    int64_t config =
        ((*search_config_) < 0)
            ? FlipMutation(&curconfig[static_cast<uint32>(user->opcode())], -10,
                           HloOpcodeString(op->opcode()) + "->" +
                               HloOpcodeString(user->opcode()))
            : curconfig[static_cast<uint32>(user->opcode())];
    VLOG(2) << "ConditionalCodeMotion: Add reuses carried by instr: "
            << op->ToString() << "=>" << user->ToString() << " : " << config
            << "\n";
    if (config < 0) {
      // Assume the reuse decreases with increasing user count.
      int count1 = CountNonLeafOps(op->users());
      int count2 = CountNonLeafOps(user->operands());
      return (-config) / count1 / count2;
    }
    return config;
  }
  void clear_recently_visited() {
    for (const auto& boundary : new_boundaries_) {
      visited_count_.erase(boundary.operands()[0]);
    }
  }
  // Returns true if `instruction` is worth hoisting.
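  // The decision is table-driven: move_config_ is indexed by the opcode of
  // `instruction` (row) and, when the row has more than one entry, by the
  // opcode of its first operand (column); a zero entry means "do not move".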
  bool WorthHoisting(HloInstruction* instruction, bool is_inside_branch) {
    // This is needed for the "moving-in" transformation, to prevent the root
    // of the parent computation (which contains the conditional) from being
    // moved inside the conditional.
    HloOpcode opcode = instruction->opcode();
    if (opcode == HloOpcode::kTuple &&
        instruction == conditional_parent_->root_instruction()) {
      return false;
    }
    // It is not safe to move collective ops from outside to inside
    // conditional branches, as it may cause synchronization problems when
    // different layouts are assigned to different branches.
    if (DynCast<HloCollectiveInstruction>(instruction) && !is_inside_branch) {
      return false;
    }

    // It is not legal to move the parameter instructions.
    if (opcode == HloOpcode::kParameter) {
      return false;
    }

    // Use configuration given from outside (e.g., by the autotuner).
    std::vector<int64>& curconfig = move_config_[static_cast<uint32>(opcode)];
    auto col = (curconfig.size() == 1) ? 0
               : (instruction->operand_count() > 0)
                   ? static_cast<uint32>(instruction->operand(0)->opcode())
                   : 0;
    VLOG(2) << "column = " << col << "\n";
    VLOG(2) << "config size = " << curconfig.size() << "\n";
    VLOG(2) << "search_config = " << *search_config_ << "\n";
    CHECK(col < curconfig.size());
    uint32 config = ((*search_config_) > 0)
                        ? FlipMutation(&curconfig[col], 1,
                                       "Move-" + HloOpcodeString(opcode))
                        : curconfig[col];
    VLOG(2) << "Checking if instruction is worth moving: " << config << "\n";
    VLOG(2) << "after checking search_config = " << *search_config_ << "\n";
    return (config != 0);
  }

  int64 ReusesBeforeBoundary(HloInstruction* user) {
    int64_t reuses = 0;
    for (auto op : user->operands()) {
      // The operand must be an instruction that is not going to be moved (if
      // user is inside the conditional); otherwise it must be the conditional
      // itself and its user must be outside of the conditional.
      if (!ContainsKey(visited_count_, op) && op != conditional_) {
        continue;
      }
      if (auto tuple_gte = DynCast<HloGetTupleElementInstruction>(user)) {
        if (op->opcode() == HloOpcode::kConditional) {
          auto tuple = op->branch_computation(0)->root_instruction();
          if (tuple->opcode() == HloOpcode::kTuple) {
            auto index = tuple_gte->tuple_index();
            CHECK(index < tuple->operand_count());
            op = tuple->mutable_operand(index);
          }
        }
        reuses += ReusesCarriedBy(op, user->users()[0]);
      } else {
        reuses += ReusesCarriedBy(op, user);
      }
    }
    VLOG(2) << "Reuses before instruction " << user->ToString() << ":" << reuses
            << "\n";
    return reuses;
  }

  int64 ReusesAfterBoundary(HloInstruction* user) {
    CHECK(user != nullptr);
    auto all_users = user->users();
    // For now, assume that if an instruction has multiple consumers, it
    // will not be reused, as the reuse may require duplication in
    // fusion and so is expensive. If the situation changes in the future,
    // some aspects of the overall algorithm need to be redesigned to
    // accommodate the change.
    if (all_users.size() > 1) {
      VLOG(2) << "Having multiple users from: " << user->ToString() << "\n";
      return 0;
    }
    if (!all_users.empty()) {
      auto op = all_users[0];
      int64_t reuses = 0;
      // Only count reuses that run through the conditional root.
      if (op == conditional_->branch_computation(0)->root_instruction()) {
        int64_t index = op->operand_index(user);
        for (auto op2 : conditional_->users()) {
          // If the use is not a get-tuple-element, do not consider it for now.
          if (op2->opcode() == HloOpcode::kGetTupleElement) {
            auto tuple_opd = static_cast<HloGetTupleElementInstruction*>(op2);
            if (index == tuple_opd->tuple_index()) {
              all_users = op2->users();
              if (!all_users.empty()) {
                reuses += ReusesCarriedBy(user, all_users[0]);
                break;
              }
            }
          }
        }
      } else if (ContainsKey(visited_count_, op)) {
        reuses += ReusesCarriedBy(user, op);
      }
      VLOG(2) << "reuses after instruction " << user->ToString() << ":"
              << reuses << "\n";
      return reuses;
    }
    return 0;
  }

  int64 BenefitForMovingBoundaries(const std::vector<Boundary>& boundaries,
                                   bool perform_reuse_analysis = true) {
    int64_t reuses_before = 0, reuses_after = 0;
    if (boundaries.size() == 1) {
      if (boundaries[0].IsOutsideBranch() &&
          boundaries[0].operands()[0]->opcode() ==
              HloOpcode::kGetTupleElement) {
        // The only boundary of moving-in is the get_tuple_element op.
        return -1;
      }
      if (boundaries[0].IsInsideBranch() &&
          boundaries[0].operands()[0]->opcode() == HloOpcode::kTuple) {
        // The only boundary of moving-out is the tuple op inside branches.
        return -1;
      }
    }
    // If trying alternative moving configurations, turn off reuse analysis.
    if (!perform_reuse_analysis) {
      return 1;
    }
    // For cases like:
    // branch0 {
    //   ROOT copy
    // }
    // branch1 {
    //   ...
    // }
    // cond = conditional(branch0, branch1)
    // copy = copy(cond)
    //
    // We can fold the two copies, thus reducing computation.
    auto get_copy_folding_benefit = [&](HloInstruction* hlo) -> int64 {
      if (hlo->opcode() != HloOpcode::kCopy) {
        return 0;
      }
      const HloGetTupleElementInstruction* gte =
          DynCast<HloGetTupleElementInstruction>(hlo->operand(0));
      if (gte == nullptr) {
        return 0;
      }
      const HloInstruction* conditional = gte->operand(0);
      if (conditional != conditional_) {
        return 0;
      }
      int64_t benefit = 0;
      for (auto* branch : conditional->called_computations()) {
        HloInstruction* root = branch->root_instruction();
        if (root->opcode() == HloOpcode::kTuple) {
          const auto* tuple_operand = root->operand(gte->tuple_index());
          if (tuple_operand->opcode() == HloOpcode::kCopy) {
            if (Shape::Equal()(tuple_operand->operand(0)->shape(),
                               hlo->shape())) {
              benefit += 10;
            }
          }
        }
      }
      return benefit;
    };
    for (const Boundary& b : boundaries) {
      auto op = b.operands()[0];
      if (op == conditional_->branch_computation(0)->root_instruction()) {
        continue;
      }
      VLOG(2) << "Benefit for " << op->ToString();
      reuses_before += ReusesBeforeBoundary(op);
      VLOG(2) << "Reuses before boundary so far: " << reuses_before << "\n";
      reuses_after += ReusesAfterBoundary(op);
      VLOG(2) << "Reuses after boundary so far : " << reuses_after << "\n";
    }

    int64_t copy_folding_benefit = 0;
    if (boundaries[0].IsOutsideBranch()) {
      for (const Boundary& b : boundaries) {
        auto op = b.operands()[0];
        copy_folding_benefit += get_copy_folding_benefit(op);
      }
    }
    VLOG(2) << "Copy folding benefit: " << copy_folding_benefit;

    if (reuses_after == 0 && reuses_before == 0 && copy_folding_benefit == 0) {
      return -1;
    } else if (boundaries[0].IsInsideBranch()) {
      return reuses_after - reuses_before;
    } else {
      return reuses_before - reuses_after - 1 + copy_folding_benefit;
    }
  }

  Boundary GetNextBoundary(const Boundary& b, int64_t op_index) {
    Boundary b2(b.GetPosition());
    for (int j = 0; j < b.operands().size(); ++j) {
      HloInstruction* inst = b.operands()[j];
      CHECK(inst != nullptr);
      HloInstruction* op = (b.IsInsideBranch()) ? inst->operands()[op_index]
                                                : inst->users()[op_index];
      CHECK(op != nullptr);
      b2.mutable_operands().push_back(op);
    }
    return b2;
  }

  // Checks whether it is safe to move a boundary when visited through a
  // dependent already considered for moving.
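  // A boundary with multiple dependents is held back until every dependent
  // has been visited; its visit count is tracked in visited_count_, and only
  // the visit that completes the count releases it for moving.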
  bool IsSafeToMoveBoundary(const Boundary& next_boundary) {
    int64_t next_boundary_count =
        (next_boundary.IsInsideBranch())
            ? next_boundary.operands()[0]->user_count()
            : CountNonLeafOps(next_boundary.operands()[0]->operands());
    if (next_boundary_count <= 1) {
      // If the boundary has a single dependent or none, it is safe to move.
      return true;
    } else {
      if (!ContainsKey(visited_count_, next_boundary.operands()[0])) {
        VLOG(2) << "Skip next boundary " << next_boundary.ToString() << "\n"
                << " because it has multiple dependents: "
                << next_boundary_count << "\n";
        visited_count_[next_boundary.operands()[0]] = 1;
        new_boundaries_.push_back(next_boundary);
      } else {
        auto pos = std::find(new_boundaries_.begin(), new_boundaries_.end(),
                             next_boundary);
        if (pos != new_boundaries_.end() ||
            next_boundary.operands().size() == 1) {
          int count = ++visited_count_[next_boundary.operands()[0]];
          if (count == next_boundary_count) {
            VLOG(2) << "Recovering next boundary " << next_boundary.ToString()
                    << "\n"
                    << " because all of its dependents have been visited: "
                    << next_boundary_count << "\n";
            visited_count_.erase(next_boundary.operands()[0]);
            if (pos != new_boundaries_.end()) {
              new_boundaries_.erase(pos);
            }
            return true;
          }
        } else {
          VLOG(2) << "Skip incompatible multi-dependent boundary: "
                  << next_boundary.ToString() << ":" << next_boundary_count
                  << "\n";
        }
      }
    }
    return false;
  }
  // This function is reused both for moving a boundary out of and into a
  // conditional. As a result, readability is somewhat compromised. It might
  // be nice to refactor it so that the outside/inside considerations are
  // factored into separate function-pointer parameters.
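  // A minimal sketch of such a refactoring (hypothetical, not implemented
  // here): the operand-vs-user asymmetry could be captured by callbacks,
  // e.g.
  //   void AddBoundaries(
  //       const Boundary& boundary,
  //       absl::FunctionRef<int64_t(const Boundary&)> dependent_count,
  //       absl::FunctionRef<Boundary(const Boundary&, int64_t)> next);
  // so that the traversal loop below no longer branches on IsInsideBranch().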
  void AddBoundaries(const Boundary& boundary) {
    BoundaryVisitor visitor;
    visitor.AddToWorkList(boundary);
    while (visitor.HasNextBoundary()) {
      Boundary b = visitor.PopNextBoundary();
      VLOG(2) << "visiting boundary " << b.ToString() << "\n";
      if ((b.IsOutsideBranch() || InstructionWithinBranchIdentical(
                                      b.operands(), is_layout_sensitive_)) &&
          IsSafeToMoveBoundary(b) &&
          WorthHoisting(b.operands()[0], b.IsInsideBranch())) {
        connected_boundaries_.push_back(b);
        VLOG(2) << "boundary can be moved\n";
        int64_t operand_count = (b.IsInsideBranch())
                                    ? b.operands()[0]->operand_count()
                                    : b.operands()[0]->users().size();
        for (int i = 0; i < operand_count; i++) {
          Boundary next_boundary = GetNextBoundary(b, i);
          VLOG(2) << "Add operand/user " << i << " to visit later\n";
          visitor.AddToWorkList(next_boundary);
        }
      } else {
        VLOG(2) << "boundary cannot be moved\n";
        visited_count_[b.operands()[0]] = 1;
        new_boundaries_.push_back(b);
      }
    }
  }
  std::vector<Boundary> BoundariesToMoveInOrOut(HloInstruction* conditional,
                                                const Boundary& b) {
    // At the beginning of optimization, the conditional itself is added to
    // the worklist. Here the conditional is expanded into two sets of
    // boundaries: the first set contains a single inside-branch boundary
    // holding the roots of all branches; the second set contains one
    // outside-branch boundary per user of the conditional.
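    // For example, given
    //   %c = conditional(%pred, %arg0, %arg1), %u0 = gte(%c, 0), %u1 = gte(%c, 1)
    // the inside-branch boundary holds {root(branch_0), root(branch_1)},
    // and two outside-branch boundaries hold {%u0} and {%u1} respectively.
    // (The instruction names here are illustrative only.)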
    HloInstruction* inst = b.operands()[0];
    if (inst == conditional) {
      int branch_count = inst->branch_count();
      // Add conditional roots as a new boundary to visit.
      Boundary boundary_in(Boundary::Position::kInsideBranch);
      for (int i = 0; i < branch_count; i++) {
        HloComputation* branch_computation = inst->branch_computation(i);
        HloInstruction* root_inst = branch_computation->root_instruction();
        CHECK(root_inst != nullptr);
        boundary_in.mutable_operands().push_back(root_inst);
      }
      new_boundaries_.push_back(boundary_in);
      // Add conditional users as new boundaries to visit.
      for (auto u : inst->users()) {
        Boundary boundary_out(Boundary::Position::kOutsideBranch);
        boundary_out.mutable_operands().push_back(u);
        new_boundaries_.push_back(boundary_out);
      }
    } else {
      AddBoundaries(b);
    }
    return connected_boundaries_;
  }
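  // Appends the boundaries newly created during the traversal to b.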
  void AddNewBoundaries(std::vector<Boundary>& b) {
    b.insert(b.end(), new_boundaries_.begin(), new_boundaries_.end());
  }
};

ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion(
    HloInstruction* conditional, const Boundary& cur_boundary,
    std::vector<Boundary>& to_move, std::vector<Boundary>& new_boundaries,
    absl::flat_hash_map<HloInstruction*, int>& visited_count) {
  GroupConnectedBoundaries connect(conditional, is_layout_sensitive_,
                                   visited_count, &move_config_, &reuse_config_,
                                   &search_config_);
  auto move_in_or_out =
      connect.BoundariesToMoveInOrOut(conditional, cur_boundary);
  if (!move_in_or_out.empty()) {
    auto benefit = connect.BenefitForMovingBoundaries(
        move_in_or_out, search_config_map_.empty());
    VLOG(2) << "benefit of moving in or out "
            << cur_boundary.operands()[0]->ToString() << ":" << benefit << "\n";
    if (benefit >= 0) {
      new_boundaries.clear();
      connect.AddNewBoundaries(new_boundaries);
      // The whole sequence in move_in_or_out is either all moving into a
      // conditional, or all moving out of a conditional. Looking only at the
      // first entry of the sequence is therefore sufficient to know in which
      // direction the move is intended.
      to_move = move_in_or_out;
      return Decision(to_move[0].IsInsideBranch()
                          ? Decision::Direction::kMoveOutOfBranch
                          : Decision::Direction::kMoveIntoBranch,
                      benefit);
    } else {
      connect.clear_recently_visited();
    }
  } else {
    connect.AddNewBoundaries(new_boundaries);
  }
  return Decision(Decision::Direction::kNoChange, 0);
}

StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
  VLOG(2) << "Begin a new pass of conditional code motion optimization.\n";
  // Used to support debugging of the optimization, by disabling it after it
  // has been applied a pre-determined number of times (to isolate the impact
  // of individual transformations).
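  // (With the standard XLA fuel mechanism, this limit is typically supplied
  // through the --xla_fuel flag, e.g. --xla_fuel=conditional_code_motion=10;
  // the value shown is illustrative only.)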
  if (!ConsumeFuel("conditional_code_motion", [&] {
        return "Skipping conditional opt after allowed limit reached 0.\n";
      })) {
    return false;
  }
  bool changed = false;
  bool cleanup_changed = false;
  {
    HloPassPipeline subpipeline("before_conditional_code_motion");
    subpipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/is_layout_sensitive_);
    subpipeline.AddPass<HloDCE>();
    TF_ASSIGN_OR_RETURN(auto cleanup_changed_now, subpipeline.Run(module));
    cleanup_changed |= cleanup_changed_now;
  }
  // Gather all the conditional ops in the module ahead of time, to avoid
  // potential complications from modifying the code in ways that affect
  // traversal.
  std::vector<HloInstruction*> conditional_ops;
  // Track how many times each branch computation is shared.
  absl::flat_hash_map<HloComputation*, int> conditional_computations;
  for (auto* comp : module->MakeComputationPostOrder()) {
    for (auto* instr : comp->MakeInstructionPostOrder()) {
      if (instr->opcode() == HloOpcode::kConditional) {
        int branch_count = instr->branch_count();
        for (int i = 0; i < branch_count; ++i) {
          HloComputation* branch_i = instr->branch_computation(i);
          if (ContainsKey(conditional_computations, branch_i)) {
            conditional_computations[branch_i]++;
          } else {
            conditional_computations[branch_i] = 0;
          }
        }
        if (instr->shape().IsTuple()) {
          // A tuple-shaped conditional is only considered if all of its
          // users are get-tuple-element, since code motion may change the
          // shape of the conditional's result tuple.
          bool can_change_tuple_shape = true;
          for (auto user : instr->users()) {
            VLOG(2) << "user is: " << user->ToString() << "\n";
            if (user->opcode() != HloOpcode::kGetTupleElement) {
              can_change_tuple_shape = false;
            }
          }
          if (can_change_tuple_shape) {
            conditional_ops.push_back(instr);
          }
        } else {
          conditional_ops.push_back(instr);
        }
      }
    }
  }

  int64_t conditional_index = 0;
  // Used to collect mappings between cloned instructions.
  HloCloneContext clone_context(module);
  for (HloInstruction* conditional : conditional_ops) {
    if (conditional_index == 0 || !search_config_map_.empty()) {
      auto config_entry = search_config_map_.find(conditional_index);
      if (config_entry != search_config_map_.end()) {
        search_config_ = config_entry->second;
        VLOG(2) << "config entry value extracted:" << search_config_.size();
        search_config_index_ = 0;
      }
      VLOG(2) << "Obtaining default configuration for conditional "
              << conditional_index << "\n";
      SetDefaultMoveConfig();
      VLOG(2) << "Done obtaining default configuration\n";
      conditional_index++;
    }
    int branch_count = conditional->branch_count();
    // Check for shared conditional computations.
    bool conditional_is_shared = false;
    for (int i = 0; i < branch_count; ++i) {
      HloComputation* branch_i = conditional->branch_computation(i);
      if (conditional_computations[branch_i] > 0) {
        conditional_is_shared = true;
        break;
      }
    }

    // Boundaries to move out of or to move into the branches.
    std::vector<std::vector<Boundary>> to_move_out, to_move_in;
    std::vector<std::vector<Boundary>> new_boundaries_for_moveout;
    std::vector<std::vector<Boundary>> new_boundaries_for_movein;
    // Number of times each instruction has been visited for moving.
    absl::flat_hash_map<HloInstruction*, int> visited_count;
    int benefit_move_out = 0, benefit_move_in = 0;
    Decision::Direction final_d = Decision::Direction::kNoChange;
    // The conditional is moved into the worklist as the seed (starting
    // point). When it is visited by GroupConnectedBoundaries, the
    // conditional is expanded into multiple seeds: its branch roots and its
    // users. A kNoChange decision is always returned for the conditional
    // itself, so that the other seeding boundaries can be visited in turn.
    BoundaryVisitor visitor(conditional);
    VLOG(2) << "Analyzing conditional:" << conditional->ToString() << "\n";
    // Try to visit all the boundaries, collect the analysis results, and
    // save all the beneficial non-conflicting decisions. If two decisions
    // conflict with each other, save the more beneficial one.
    while (visitor.HasNextBoundary()) {
      std::vector<Boundary> to_move, next_boundary;
      Boundary boundary = visitor.PopNextBoundary();
      VLOG(2) << "Analyzing boundary:" << boundary.ToString() << "\n";
      auto d = ConsiderCodeMotion(conditional, boundary, to_move, next_boundary,
                                  visited_count);
      switch (d.GetDirection()) {
        case Decision::Direction::kMoveOutOfBranch:
          VLOG(2) << "Local Decision is move out of branch\n";
          to_move_out.push_back(to_move);
          new_boundaries_for_moveout.push_back(next_boundary);
          benefit_move_out += d.GetBenefit();
          if (benefit_move_out >= benefit_move_in) {
            final_d = Decision::Direction::kMoveOutOfBranch;
            VLOG(2) << "Current Decision is move out of branch ("
                    << to_move_out.size() << ")\n";
          } else {
            VLOG(2) << "Current Decision remains move into branch\n";
          }
          break;
        case Decision::Direction::kMoveIntoBranch:
          VLOG(2) << "Decision is move into branch\n";
          to_move_in.push_back(to_move);
          new_boundaries_for_movein.push_back(next_boundary);
          benefit_move_in += d.GetBenefit();
          if (benefit_move_out >= benefit_move_in) {
            VLOG(2) << "Current Decision remains move out of branch\n";
          } else {
            final_d = Decision::Direction::kMoveIntoBranch;
            VLOG(2) << "Current Decision is move into branch ("
                    << to_move_in.size() << ")\n";
          }
          break;
        case Decision::Direction::kNoChange:
          VLOG(2) << "Decision is no change\n";
          for (const Boundary& b : next_boundary) {
            visitor.AddToWorkList(b);
            VLOG(2) << "Adding new boundary to worklist:" << b.ToString()
                    << "\n";
          }
          break;
      }
    }
    // If a modification is to be made, the shared branches need to be cloned
    // first.
    if (final_d != Decision::Direction::kNoChange && conditional_is_shared) {
      for (int i = 0; i < branch_count; ++i) {
        HloComputation* branch_i = conditional->branch_computation(i);
        if (conditional_computations[branch_i] > 0) {
          // Cloning is absolutely needed if the computation is shared by
          // different conditionals, but it can potentially be avoided if the
          // sharing is only among branches of the same conditional. If
          // cloning these branches causes a problem due to space issues, a
          // fix can pass a vector of unique branches to the actual
          // transformations as an alternative representation of the
          // conditional branches to be modified. Right now we assume the
          // overhead of cloning is minimal, since later stages of the
          // compiler inline all the computations anyway.
          HloComputation* clone_i =
              conditional->parent()->parent()->AddEmbeddedComputation(
                  branch_i->Clone("clone", &clone_context));
          conditional->set_branch_computation(i, clone_i);
          conditional_computations[branch_i]--;
          // The analysis results need to be translated to refer to the
          // cloned instructions.
          auto update_boundary = [&](Boundary& boundary) {
            auto cloned_instr =
                clone_context.FindInstruction(boundary.operands()[i]);
            CHECK(cloned_instr != nullptr);
            VLOG(2) << "boundary before cloning:"
                    << boundary.operands()[i]->ToString() << "\n";
            boundary.mutable_operands()[i] = cloned_instr;
            VLOG(2) << "boundary after cloning:"
                    << boundary.operands()[i]->ToString() << "\n";
          };
          // Only boundaries to move out need to be updated.
          if (final_d == Decision::Direction::kMoveOutOfBranch) {
            for (int j = 0; j < to_move_out.size(); ++j) {
              std::vector<Boundary>& m = to_move_out[j];
              std::for_each(m.begin(), m.end(), update_boundary);
            }
            for (int j = 0; j < new_boundaries_for_moveout.size(); ++j) {
              std::vector<Boundary>& m = new_boundaries_for_moveout[j];
              std::for_each(m.begin(), m.end(), update_boundary);
            }
          }
        }
      }
      VLOG(2) << "Cloned branches as needed: " << conditional->ToString()
              << "\n";
    }
    // Only one of to_move_out and to_move_in is applied, since a single
    // optimization direction (final_d) is chosen for each conditional.
    if (final_d == Decision::Direction::kMoveOutOfBranch) {
      CHECK(to_move_out.size() == new_boundaries_for_moveout.size());
      for (int i = 0; i < to_move_out.size(); ++i) {
        TF_ASSIGN_OR_RETURN(bool result,
                            MoveInstructionOut(conditional, to_move_out[i],
                                               new_boundaries_for_moveout[i]));
        changed |= result;
      }
      VLOG(2) << "Done moving out of branches " << to_move_out.size()
              << " times.\n";
      if (!ConsumeFuel("conditional_code_motion", [&] {
            return "Skipping conditional opt after allowed limit reached 0.\n";
          })) {
        break;
      }
    } else if (final_d == Decision::Direction::kMoveIntoBranch) {
      CHECK(to_move_in.size() == new_boundaries_for_movein.size());
      for (int i = 0; i < to_move_in.size(); ++i) {
        TF_ASSIGN_OR_RETURN(bool result,
                            MoveInstructionIn(conditional, to_move_in[i],
                                              new_boundaries_for_movein[i]));
        changed |= result;
      }
      VLOG(2) << "Done moving into branches " << to_move_in.size()
              << " times.\n";
      if (!ConsumeFuel("conditional_code_motion", [&] {
            return "Skipping conditional opt after allowed limit reached 0.\n";
          })) {
        break;
      }
    } else if (pursue_full_conditional_code_motion_ && !conditional_is_shared) {
      // Invoke special handling for convert rematerialization/hoisting.
      // We need to make sure no sharing is present in the branches, because
      // no cloning has been done by the earlier analysis.
      // TODO(b/165848866): extend the solution to handle cloning for the
      // special move.
      TF_ASSIGN_OR_RETURN(
          bool convert_result,
          ConvertSpecialMove(conditional, is_layout_sensitive_));
      if (convert_result) {
        VLOG(2) << "Done special moving of convert\n";
        if (!ConsumeFuel("conditional_code_motion", [&] {
              return "Skipping conditional opt after allowed limit reached "
                     "0.\n";
            })) {
          break;
        }
      }
      changed |= convert_result;
    }
  }
  if (changed) {
    HloPassPipeline subpipeline(
        "after_conditional_code_motion_after_convert_hoisting");
    VLOG(2) << "starting after motion passes: DCE\n";
    subpipeline.AddPass<HloDCE>();
    subpipeline.AddPass<TupleSimplifier>();
    subpipeline.AddPass<HloDCE>();
    TF_ASSIGN_OR_RETURN(auto cleanup_changed_now, subpipeline.Run(module));
    cleanup_changed |= cleanup_changed_now;
  }
  if (cleanup_changed) {
    VLOG(2) << "subpipeline cleanup has modified the code\n";
  }
  return changed;
}

void ConditionalCodeMotion::SetDefaultMoveConfig() {
  VLOG(2) << "search_config_index = " << search_config_index_ << "\n";
  VLOG(2) << "search_config_ size = " << search_config_.size() << "\n";
  int64_t cur_search_config = (search_config_index_ < 0 ||
                               search_config_index_ >= search_config_.size())
                                  ? 0
                                  : search_config_[search_config_index_];
  enum class TuningOption {
    kDoNotTune = 0,
    kTuneTransformationDecision = 1,
    kTuneReuseModel = 2,
  };
  TuningOption tuning_option =
      (cur_search_config == 0)  ? TuningOption::kDoNotTune
      : (cur_search_config > 0) ? TuningOption::kTuneTransformationDecision
                                : TuningOption::kTuneReuseModel;

  auto row = HloOpcodeCount();
  auto col = row;
  VLOG(2) << "Start setting default configuration\n";
  reuse_config_.clear();
  move_config_.clear();
  reuse_config_.reserve(row);
  move_config_.reserve(row);
  for (int64_t opcode = 0; opcode < row; ++opcode) {
    // Record the reuses carried between this opcode and each possible
    // operand opcode.
    std::vector<int64> reuse_vec(col, 0);
    for (uint32 j = 0; j < col; ++j) {
      reuse_vec[j] = ReusesCarriedBy(static_cast<HloOpcode>(opcode),
                                     static_cast<HloOpcode>(j));
    }
    reuse_config_.push_back(reuse_vec);
    // Record whether an instruction with this opcode is preferred to be
    // moved.
    std::vector<int64> move_vec;
    switch (tuning_option) {
      case TuningOption::kTuneTransformationDecision:
        // Tuning transformation decisions --- start with all yes.
        // Only a single entry is needed if we don't consider operands of an
        // op when searching/tuning transformation decisions.
        move_vec.push_back(1);
        break;
      case TuningOption::kTuneReuseModel:
        // Tune the ReusesCarriedBy results only.
      case TuningOption::kDoNotTune:
        // No tuning --- use the default configuration.
        // Use the opcode of the first operand to configure the default.
        move_vec.reserve(col);
        for (uint32 j = 0; j < col; ++j) {
          move_vec.push_back(WorthHoisting(static_cast<HloOpcode>(opcode),
                                           static_cast<HloOpcode>(j)));
        }
        break;
    }
    move_config_.push_back(move_vec);
  }
}

// The search configuration is specified using a string in the format
// 'config1;config2;...;config_n', where each config_i is in the format
// 'index,start,max,stride' (four integers separated by commas). They specify
// the index of the conditional being configured, the index of the first
// transformation decision to flip for the conditional, the maximum number of
// decisions to flip, and how many decisions to skip in between the flips.
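// For example (values are illustrative only), the string "0,1,2,1;1,0,3,2"
// configures two conditionals: for conditional 0, start flipping at decision
// 1, flip at most 2 decisions, and skip 1 decision between flips; for
// conditional 1, start at decision 0, flip at most 3 decisions, and skip 2
// decisions between flips.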
void ConditionalCodeMotion::ParseSearchConfiguration(
    const std::string& search_config) {
  if (search_config.empty()) {
    return;
  }
  search_config_index_ = 0;
  std::vector<std::string> configs = absl::StrSplit(search_config, ';');
  for (const std::string& config : configs) {
    std::vector<std::string> specs = absl::StrSplit(config, ',');
    CHECK_EQ(specs.size(), 4);
    int64_t condition_index;
    CHECK(absl::SimpleAtoi(specs[0], &condition_index));
    auto& cur_config_entry = search_config_map_[condition_index];
    int64_t flip_start, max_flip, flip_stride;
    CHECK(absl::SimpleAtoi(specs[1], &flip_start));
    CHECK(absl::SimpleAtoi(specs[2], &max_flip));
    CHECK(absl::SimpleAtoi(specs[3], &flip_stride));
    int64_t cur_config = MakeSearchConfig(flip_start, max_flip, flip_stride);
    cur_config_entry.push_back(cur_config);
    VLOG(2) << "Setting search config " << condition_index << "->" << cur_config
            << "\n";
  }
}
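
// Usage sketch (a minimal example; the constructor arguments shown are
// assumptions based on the members used in this file, not a confirmed API):
//   ConditionalCodeMotion pass(/*is_layout_sensitive=*/true,
//                              /*pursue_full_conditional_code_motion=*/true);
//   TF_ASSIGN_OR_RETURN(bool changed, pass.Run(module));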

}  // namespace conditional_opt

}  // namespace xla