/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/conditional_code_motion.h"

#include <deque>
#include <iterator>
#include <stack>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"

namespace xla {

namespace conditional_opt {

class BoundaryVisitor {
 public:
  // Start with an existing conditional computation.
  explicit BoundaryVisitor(HloInstruction* conditional) {
    Boundary b(Boundary::Position::kInsideBranch);
    b.mutable_operands().push_back(conditional);
    worklist_.push_back(b);
  }
  // Start with an empty work list.
  BoundaryVisitor() {}
  // Get the next boundary to visit.
  Boundary PopNextBoundary() {
    CHECK(!worklist_.empty());
    Boundary b = worklist_.front();
    worklist_.pop_front();
    // If b has already been visited, it must have multiple users and is
    // already in new boundaries; skip it. Only the first operand of b is
    // checked because b is expected to have at least one operand, and all
    // the operands in b must be identical instructions from different
    // branches for b to be moved.
    while (!worklist_.empty() && ContainsKey(visited_, b.operands()[0])) {
      b = worklist_.front();
      worklist_.pop_front();
    }
    visited_.insert(b.operands()[0]);
    return b;
  }
  void AddToWorkList(const Boundary& b) {
    CHECK(!b.operands().empty());
    worklist_.push_back(b);
  }

  bool HasNextBoundary() {
    while (!worklist_.empty()) {
      Boundary b = worklist_.front();
      if (!ContainsKey(visited_, b.operands()[0])) {
        break;
      }
      worklist_.pop_front();
    }
    return !worklist_.empty();
  }

 private:
  // The worklist is the deque of boundaries yet to be visited.
  std::deque<Boundary> worklist_;
  absl::flat_hash_set<HloInstruction*> visited_;
};

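// Counts the distinct non-constant instructions in `ops`; constants are
// treated as leaves and are not counted.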
template <class OpCollection>
int64 CountNonLeafOps(const OpCollection& ops) {
  absl::flat_hash_set<HloInstruction*> op_set;
  for (auto op : ops) {
    if (!op_set.contains(op) && op->opcode() != HloOpcode::kConstant) {
      op_set.insert(op);
    }
  }
  return op_set.size();
}

// Returns an estimate of the potential reuses carried by a given pair of
// instructions. Different integers classify different levels of reuse.
// This is used as a placeholder only, assuming all instructions can be
// fused to enable data reuse.
int64 ReusesCarriedBy(HloOpcode op, HloOpcode user) {
  // Reuses in some ways work like forces that pull instructions
  // towards each other. We use a number 0-10 to classify how strong the force
  // is between a pair of operations. Given a group of instructions that can be
  // moved together, if the forces inside a conditional are stronger, the group
  // will be moved inside or remain inside the conditional; otherwise, it will
  // be moved outside to or remain outside of the conditional.
  switch (user) {
    case HloOpcode::kGetTupleElement:
      return 0;
    case HloOpcode::kConvert:
      // Because a convert is treated as not movable when it follows a Dot or
      // convolution, if op is a dot or convolution here, they must be
      // separated by a conditional boundary. In that case we do not try to
      // pull the convert inside the conditional to be together with the dot
      // or convolution.
      switch (op) {
        case HloOpcode::kConvolution:
        case HloOpcode::kDot:
          return 0;
        default:
          break;
      }
      break;
    default:
      break;
  }
  switch (op) {
    // These instructions do not carry weight of reuse themselves.
    case HloOpcode::kParameter:
    case HloOpcode::kConstant:
    case HloOpcode::kGetTupleElement:
      return 0;
    case HloOpcode::kConditional:
      return 10;
    default:
      return -10;
  }
}

// Returns true if `op` is worth hoisting.
bool WorthHoisting(HloOpcode op, HloOpcode child_op) {
  // TODO(b/169182921): The following cost model is rather incomplete and
  // will need to be extended to cover most of the element-wise ops.
  switch (op) {
    case HloOpcode::kConvert:
      // If Convert is after AllReduce, it is worth moving the AllReduce out
      // of the conditional for AR/CRS combining. If Convert is after other
      // ops such as Dot or Convolution, it is better to keep the convert
      // within the conditional so that it can be fused with the Dot or
      // Convolution.
      switch (child_op) {
        case HloOpcode::kAllReduce:
        case HloOpcode::kReshape:
        case HloOpcode::kGetTupleElement:
          return true;
        default:
          return false;
      }
    case HloOpcode::kGetTupleElement:
      switch (child_op) {
        // Do not move a GTE if its operand is a parameter.
        case HloOpcode::kParameter:
          return false;
        default:
          return true;
      }
    case HloOpcode::kAllReduce:
    case HloOpcode::kReduceScatter:
    case HloOpcode::kAbs:
    case HloOpcode::kReduce:
    case HloOpcode::kAdd:
    case HloOpcode::kPower:
    case HloOpcode::kCopy:
    case HloOpcode::kConstant:
    case HloOpcode::kSubtract:
    case HloOpcode::kMultiply:
    case HloOpcode::kDivide:
    case HloOpcode::kTuple:
    case HloOpcode::kSqrt:
    case HloOpcode::kRsqrt:
    case HloOpcode::kReshape:
    case HloOpcode::kMinimum:
    case HloOpcode::kMaximum:
      return true;
    default:
      return false;
  }
}

// Compares whether the instructions to be visited in each branch are
// identical.
bool InstructionWithinBranchIdentical(
    const std::vector<HloInstruction*>& instructions,
    bool is_layout_sensitive) {
  // Identical here includes the shapes of all operands being equal.
  auto eq_operand = [&](const HloInstruction* a, const HloInstruction* b) {
    bool eq_operands = is_layout_sensitive
                           ? ShapeUtil::Equal(a->shape(), b->shape())
                           : ShapeUtil::Compatible(a->shape(), b->shape());
    return eq_operands;
  };

  auto eq_computations = [](const HloComputation* a, const HloComputation* b) {
    return *a == *b;
  };

  if (instructions.empty()) {
    return false;
  }

  if (instructions[0]->IsCrossModuleAllReduce()) {
    return std::all_of(
        instructions.begin(), instructions.end(),
        [&](HloInstruction* instruction) {
          if (!instruction->IsCrossModuleAllReduce()) {
            return false;
          }
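          // Temporarily unify the channel ids so that two cross-module
          // all-reduces that differ only in channel id still compare as
          // identical.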
          auto old_channel_id = instruction->channel_id();
          instruction->set_channel_id(instructions[0]->channel_id());
          bool eq_instructions = instructions[0]->Identical(
              *instruction, eq_operand, eq_computations, is_layout_sensitive);
          instruction->set_channel_id(old_channel_id);
          return eq_instructions;
        });
  }

  return std::all_of(instructions.begin(), instructions.end(),
                     [&](HloInstruction* instruction) {
                       return instructions[0]->Identical(
                           *instruction, eq_operand, eq_computations,
                           is_layout_sensitive);
                     });
}

// Copies the ith instruction in a boundary to outside of the conditional, or
// does the opposite (for moving in).
Status CopyInOrOutOfConditional(
    Boundary& boundary, int64_t dest_index, HloComputation* parent,
    absl::flat_hash_map<HloInstruction*, Boundary>& hoisted_instructions) {
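  // When moving instructions out of the conditional (boundary inside a
  // branch), there is a single destination, so dest_index must be 0; when
  // moving in, dest_index selects which branch receives the copy.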
  CHECK(dest_index == 0 || boundary.IsOutsideBranch());
  HloInstruction* op = boundary.operands()[0];
  absl::InlinedVector<HloInstruction*, 4> new_operands;
  for (int i = 0; i < op->operands().size(); ++i) {
    auto op_i = op->operands()[i];
    VLOG(2) << "Looking for " << op_i->ToString() << "\n";
    if (ContainsKey(hoisted_instructions, op_i)) {
      auto new_op_i =
          FindOrDie(hoisted_instructions, op_i).operands()[dest_index];
      VLOG(2) << "new instruction:" << new_op_i->ToString() << "\n";
      new_operands.push_back(new_op_i);
    } else {
      switch (op_i->opcode()) {
        case HloOpcode::kConstant: {
          auto new_op_i = parent->AddInstruction(op_i->Clone());
          VLOG(2) << "new instruction:" << new_op_i->ToString() << "\n";
          new_operands.push_back(new_op_i);
          break;
        }
        case HloOpcode::kGetTupleElement: {
          auto gte = Cast<HloGetTupleElementInstruction>(op_i);
          int64_t index = gte->tuple_index();
          HloInstruction* root = parent->root_instruction();
          CHECK(root->opcode() == HloOpcode::kTuple &&
                index < root->operand_count());
          auto new_op_i = root->mutable_operand(index);
          VLOG(2) << "new instruction:" << new_op_i->ToString() << "\n";
          new_operands.push_back(new_op_i);
          break;
        }
        default:
          LOG(FATAL) << "Unexpected out-of-boundary instruction:"
                     << op_i->ToString() << "\n";
      }
    }
  }
  HloInstruction* new_instruction = parent->AddInstruction(
      op->CloneWithNewOperands(op->shape(), new_operands));
  VLOG(2) << "new instruction:" << new_instruction->ToString() << "\n";
  // Map each original boundary instruction to its new copy, so that users
  // that are copied later can find the correct operand.
  for (HloInstruction* op : boundary.operands()) {
    Boundary b2 = ContainsKey(hoisted_instructions, op)
                      ? hoisted_instructions[op]
                      : Boundary(boundary.IsOutsideBranch()
                                     ? Boundary::Position::kInsideBranch
                                     : Boundary::Position::kOutsideBranch);
    b2.mutable_operands().push_back(new_instruction);
    hoisted_instructions[op] = b2;
  }
  return Status::OK();
}

// Identify converts to be hoisted/rematerialized out of the branch
// computations.
absl::flat_hash_set<int64> FindSpecialConverts(HloInstruction* old_root,
                                               int branch_count,
                                               HloInstruction* conditional,
                                               bool is_layout_sensitive) {
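  // A "special" convert is a convert that feeds the branch root and has an
  // identically-shaped convert feeding the same root operand position in
  // every other branch, so the group can be hoisted past the conditional.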
  absl::flat_hash_set<int64> kspecial_convert;
  for (int64_t operand_num = 0; operand_num < old_root->operand_count();
       ++operand_num) {
    if (old_root->operand(operand_num)->opcode() != HloOpcode::kConvert) {
      continue;
    }
    bool replica = true;
    HloInstruction* kspecial_convert_candidate =
        old_root->mutable_operand(operand_num);
    // Check whether an identical candidate appears in other branches.
    for (int others = 1; others < branch_count; ++others) {
      HloInstruction* others_root =
          conditional->branch_computation(others)->root_instruction();
      bool eq_shape =
          is_layout_sensitive
              ? ShapeUtil::Equal(others_root->operand(operand_num)->shape(),
                                 kspecial_convert_candidate->shape())
              : ShapeUtil::Compatible(
                    others_root->operand(operand_num)->shape(),
                    kspecial_convert_candidate->shape());
      if ((others_root->operand(operand_num)->opcode() ==
           HloOpcode::kConvert) &&
          eq_shape) {
        // Nothing to be done.
      } else {
        replica = false;
        break;
      }
    }
    if (replica) {
      kspecial_convert.insert(operand_num);
    }
  }
  return kspecial_convert;
}

// Restructures the conditional instruction as follows:
// i.e., %result = conditional() becomes
// x = conditional()
// y.{0..n} = gte(x, {0..n})
// z = tuple(y.0, y.1, ...y.n)
// Doing so ensures that we can accommodate the possible shape-change of the
// conditional when the instructions are hoisted.
Status RestructureConditionalInstruction(HloComputation* computation,
                                         HloInstruction* conditional) {
  HloInstruction* old_root = computation->root_instruction();
  std::vector<HloInstruction*> new_operands;
  int cur_index = 0;
  for (; cur_index < ShapeUtil::TupleElementCount(conditional->shape());
       ++cur_index) {
    new_operands.push_back(
        computation->AddInstruction(HloInstruction::CreateGetTupleElement(
            ShapeUtil::GetTupleElementShape(conditional->shape(), cur_index),
            conditional, cur_index)));
  }
  HloInstruction* new_tuple =
      computation->AddInstruction(HloInstruction::CreateTuple(new_operands));
  if (old_root == conditional) {
    computation->set_root_instruction(new_tuple);
  } else {
    std::vector<HloInstruction*> new_tuple_users;
    for (auto conditional_user : conditional->users()) {
      auto is_new_gte = absl::c_find_if(
          new_operands,
          [&](HloInstruction* instr) { return instr == conditional_user; });
      if (is_new_gte == new_operands.end()) {
        new_tuple_users.push_back(conditional_user);
      }
    }
    for (auto new_tuple_user : new_tuple_users) {
      TF_RETURN_IF_ERROR(
          conditional->ReplaceUseWith(new_tuple_user, new_tuple));
    }
  }
  VLOG(2) << "computation after root restructure:\n" << computation->ToString();
  return Status::OK();
}

StatusOr<bool> ConvertSpecialMove(HloInstruction* conditional,
                                  bool is_layout_sensitive) {
  int branch_count = conditional->branch_count();
  if (branch_count <= 0) {
    return false;
  }

  // Determine whether all branch roots are tuples.
  for (int branch_num = 0; branch_num < branch_count; ++branch_num) {
    HloInstruction* branch_root =
        conditional->branch_computation(branch_num)->root_instruction();
    if (branch_root->opcode() != HloOpcode::kTuple) {
      return false;
    }
  }

  HloInstruction* old_root =
      conditional->branch_computation(0)->root_instruction();
  VLOG(2) << "BEFORE :" << conditional->parent()->parent()->ToString();
  // Identify the gte using `index'.
  auto find_gte = [](const HloInstruction* conditional_result,
                     int64_t index) -> HloInstruction* {
    for (HloInstruction* instr : conditional_result->users()) {
      if (instr->opcode() != HloOpcode::kGetTupleElement) {
        return nullptr;
      }
      if (instr->tuple_index() == index) {
        return instr;
      }
    }
    return nullptr;
  };

  // Captures tuple indices referring to converts to be
  // rematerialized/hoisted.
  absl::flat_hash_set<int64> kspecial_convert = FindSpecialConverts(
      old_root, branch_count, conditional, is_layout_sensitive);

  // Exit if we cannot find any converts to be hoisted.
  if (kspecial_convert.empty()) {
    return false;
  }

  TF_RETURN_IF_ERROR(
      RestructureConditionalInstruction(conditional->parent(), conditional));

  for (int branch = 0; branch < branch_count; branch++) {
    old_root = conditional->branch_computation(branch)->root_instruction();
    absl::flat_hash_map<HloInstruction*, int64> map_inst_to_tuple_index;
    std::vector<HloInstruction*> new_operands(old_root->operand_count());
    absl::flat_hash_set<HloInstruction*> to_hoist_set;

    for (int64_t operand_num = 0; operand_num < old_root->operand_count();
         ++operand_num) {
      map_inst_to_tuple_index[old_root->mutable_operand(operand_num)] =
          operand_num;
    }
    for (int64_t operand_num = 0; operand_num < old_root->operand_count();
         ++operand_num) {
      HloInstruction* hoist = old_root->mutable_operand(operand_num);
      if (!kspecial_convert.contains(operand_num)) {
        new_operands[operand_num] = old_root->mutable_operand(operand_num);
        continue;
      }

      to_hoist_set.insert(hoist);
      int64_t new_tuple_count = old_root->operand_count();

      // Replace the hoisted instr in the tuple with its operand(s). At
      // least one of the operands of the hoist is placed at the original
      // tuple position; the rest are appended at the end.
      bool inplace = true;
      CHECK(!hoist->operands().empty());
      for (HloInstruction* prod : hoist->operands()) {
        if (inplace) {
          map_inst_to_tuple_index[prod] = map_inst_to_tuple_index[hoist];
          new_operands[map_inst_to_tuple_index[hoist]] = prod;
          inplace = false;
        } else {
          map_inst_to_tuple_index[prod] = new_tuple_count++;
          new_operands.push_back(prod);
        }
      }
    }

    // Create the new root instruction.
    HloComputation* cur_branch = conditional->branch_computation(branch);
    HloInstruction* new_branch_root =
        cur_branch->AddInstruction(HloInstruction::CreateTuple(new_operands));
    // The shape can vary since the operands of the converts are now being
    // returned through the branches' roots.
    cur_branch->set_root_instruction(new_branch_root, true /*new shape*/);
    TF_CHECK_OK(cur_branch->RemoveInstruction(old_root));

    // Only one of the branches needs to change the conditional->parent().
    if (branch != 0) {
      continue;
    }
    HloComputation* conditional_parent = conditional->parent();
    HloInstruction* newconditional =
        conditional_parent->AddInstruction(HloInstruction::CreateConditional(
            cur_branch->root_instruction()->shape(),
            conditional->mutable_operand(0),
            absl::MakeSpan(conditional->branch_computations()),
            absl::MakeSpan(conditional->operands()).subspan(1)));
    // Ensure that all the users of conditional refer to the new one.
    TF_RETURN_IF_ERROR(
        conditional->ReplaceAllUsesWithDifferentShape(newconditional));
    TF_CHECK_OK(conditional_parent->RemoveInstruction(conditional));
    conditional = newconditional;
    // Add the hoisted instructions in the parent.
    for (HloInstruction* hoist : to_hoist_set) {
      VLOG(2) << "Hoisting instruction:" << hoist->ToString();
      int64_t hoist_index = map_inst_to_tuple_index[hoist];
      // Find the gte that captured the hoisted instr result.
      HloInstruction* gte_hoist = find_gte(conditional, hoist_index);
      CHECK(gte_hoist != nullptr);
      std::vector<HloInstruction*> new_operands;
      for (HloInstruction* op : hoist->operands()) {
        HloInstruction* gte = conditional_parent->AddInstruction(
            HloInstruction::CreateGetTupleElement(op->shape(), conditional,
                                                  map_inst_to_tuple_index[op]));
        new_operands.push_back(gte);
      }
      HloInstruction* hoisted = conditional_parent->AddInstruction(
          hoist->CloneWithNewOperands(hoist->shape(), new_operands));
      VLOG(2) << "Hoisted instruction in parent:" << hoisted->ToString();
      TF_RETURN_IF_ERROR(gte_hoist->ReplaceAllUsesWith(hoisted));
      TF_CHECK_OK(conditional_parent->RemoveInstruction(gte_hoist));
    }
    // There is no need to explicitly delete a hoisted instruction; if it is
    // dead, the subsequent DCE will remove it.
  }
  VLOG(2) << "AFTER :" << conditional->parent()->parent()->ToString();
  return true;
}

// Hoists identical ops out of the conditional. Instructions are considered
// identical if the shapes of their operands and their properties are
// identical. Starts from the root instruction of each branch and collects
// the identical ops to hoist.
StatusOr<bool> ConditionalCodeMotion::MoveInstructionOut(
    HloInstruction* conditional, std::vector<Boundary>& to_move_out,
    std::vector<Boundary>& new_boundaries) {
  if (to_move_out.empty()) {
    return false;
  }
  VLOG(1) << "Modifying code--number of boundaries to move out:"
          << to_move_out.size() << "\n";
  HloComputation* conditional_parent = conditional->parent();
  // Save the old users before adding the new conditional user instructions.
  std::vector<HloInstruction*> old_conditional_users = conditional->users();
  // Maps instructions in the conditional body to instructions hoisted outside
  // the conditional that compute the same value.
  absl::flat_hash_map<HloInstruction*, Boundary> hoisted_instructions;
  // Insert GetTupleElement before the instructions whose operands might still
  // be within the conditional.
  VLOG(1) << "before opt:"
          << conditional_parent->ToString(HloPrintOptions::Fingerprint())
          << "\n";
  int64_t op_index = 0;
  for (const Boundary& b : new_boundaries) {
    HloInstruction* op = b.operands()[0];
    CHECK(op != nullptr);
    VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n";
    HloInstruction* gtr = conditional_parent->AddInstruction(
        HloInstruction::CreateGetTupleElement(op->shape(), conditional,
                                              op_index++));
    Boundary b2(Boundary::Position::kOutsideBranch);
    b2.mutable_operands().push_back(gtr);
    hoisted_instructions[op] = b2;
  }
  // Copy boundary instructions out of the conditional.
  // Visit an operand before its users and copy it, so that the copied
  // users will point to the correct operand.
  for (int64_t i = to_move_out.size() - 1; i >= 0; i--) {
    TF_RETURN_IF_ERROR(CopyInOrOutOfConditional(
        to_move_out[i], 0, conditional_parent, hoisted_instructions));
  }
  VLOG(2) << "Done copying branch instructions out\n"
          << conditional_parent->ToString(HloPrintOptions::Fingerprint())
          << "\n";
  // Change the original users of the conditional to use the correct operands.
  HloInstruction* old_root =
      conditional->branch_computation(0)->root_instruction();
  for (auto user_instr : old_conditional_users) {
    VLOG(2) << "Checking conditional user: " << user_instr->ToString() << "\n";
    CHECK(user_instr->opcode() == HloOpcode::kGetTupleElement);
    auto tuple_opd = static_cast<HloGetTupleElementInstruction*>(user_instr);
    int64_t index = tuple_opd->tuple_index();
    CHECK(old_root->operands().size() > index);
    HloInstruction* old_opd = old_root->operands()[index];
    CHECK(old_opd != nullptr);
    VLOG(2) << "old opd = " << old_opd->ToString() << "\n";
    CHECK(ContainsKey(hoisted_instructions, old_opd));
    HloInstruction* new_opd = hoisted_instructions[old_opd].operands()[0];
    CHECK(new_opd != nullptr);
    VLOG(2) << "Try replace all uses of :" << old_opd->ToString() << "\n";
    TF_RETURN_IF_ERROR(user_instr->ReplaceAllUsesWith(new_opd));
    TF_RETURN_IF_ERROR(conditional_parent->RemoveInstruction(user_instr));
  }
  VLOG(2) << "Done changing conditional users\n"
          << conditional_parent->ToString() << "\n";
  // Create a tuple of the remaining instructions within each branch and set
  // it as the branch root.
  int64_t branch_count = conditional->branch_count();
  for (int i = 0; i < branch_count; i++) {
    auto computation = conditional->branch_computation(i);
    std::vector<HloInstruction*> elements;
    for (const auto& b1 : new_boundaries) {
      HloInstruction* op = b1.operands()[i];
      CHECK(op != nullptr);
      VLOG(2) << "Adding to root " << i << " with " << op->ToString() << "\n";
      elements.push_back(op);
    }
    HloInstruction* tuple =
        computation->AddInstruction(HloInstruction::CreateTuple(elements));
    computation->set_root_instruction(tuple, true);
    VLOG(2) << "computation is :" << computation->ToString() << "\n";
    // Remove hoisted instructions from the branches.
    for (const auto& b2 : to_move_out) {
      auto instr_to_remove = b2.operands()[i];
      // Double-check that it is safe to delete the instruction.
      // Complications may arise because some operations in the alternative
      // branches (branches 1..n) may be placed into the boundaries multiple
      // times.
      if (!computation->IsMarkedAsDead(instr_to_remove) &&
          instr_to_remove->user_count() == 0) {
        VLOG(2) << "Removing boundary:" << b2.ToString() << "\n";
        TF_RETURN_IF_ERROR(computation->RemoveInstruction(instr_to_remove));
      }
    }
  }
  // Change the conditional instruction shape to the shape of the new root.
  HloInstruction* new_root =
      conditional->branch_computation(0)->root_instruction();
  *conditional->mutable_shape() = new_root->shape();
  VLOG(1) << "done moving instructions out of branches\n"
          << conditional_parent->ToString(HloPrintOptions::Fingerprint())
          << "\n";
  return true;
}

// Hoists ops from outside of the conditional into the branches.
StatusOr<bool> ConditionalCodeMotion::MoveInstructionIn(
    HloInstruction* conditional, std::vector<Boundary>& to_move_in,
    std::vector<Boundary>& new_boundaries) {
  if (to_move_in.empty()) {
    return false;
  }
  VLOG(1) << "Modifying code---number of boundaries to move in:"
          << to_move_in.size() << "\n";
  VLOG(1) << "before opt:"
          << conditional->parent()->ToString(HloPrintOptions::Fingerprint())
          << "\n";
  // Mapping of instructions to be moved to their new representations.
  absl::flat_hash_map<HloInstruction*, Boundary> hoisted_instructions;
  int64_t to_move_in_size = to_move_in.size();
  int64_t branch_count = conditional->branch_count();
  HloGetTupleElementInstruction* tuple_use =
      DynCast<HloGetTupleElementInstruction>(to_move_in[0].operands()[0]);
  // If use_index is -1, the old conditional root entry used by the to_move_in
  // instructions still needs to be included as an entry of the modified
  // conditional root, and the new result of the to_move_in instructions
  // needs to be added as an extra entry of the modified root; otherwise, the
  // old root entry will be replaced with the new result in the modified root.
  // The entry replacement should be allowed only if tuple_use has <= 1 users.
  int64_t use_index = (tuple_use != nullptr && tuple_use->user_count() == 1)
                          ? tuple_use->tuple_index()
                          : -1;
  VLOG(2) << "Tuple use index = " << use_index << "\n";
  // Number of old conditional entries still to be used outside.
  // If the conditional shape is not a tuple, a tuple will be created and
  // subscript 0 will be used to save the old operand being used.
  int64_t op_index =
      conditional->shape().IsTuple()
          ? ((use_index >= 0) ? conditional->shape().tuple_shapes_size() - 1
                              : conditional->shape().tuple_shapes_size())
          : 0;
  // Used to map the tuple_use instruction to its operand.
  Boundary b_opd_use(Boundary::Position::kInsideBranch);
  Boundary b_old_root(Boundary::Position::kInsideBranch);
  // Create a new root instruction in each branch.
  for (int i = 0; i < branch_count; i++) {
    auto computation = conditional->branch_computation(i);
    auto old_root = computation->root_instruction();
    b_old_root.mutable_operands().push_back(old_root);
    std::vector<HloInstruction*> operands;
    if (old_root->opcode() == HloOpcode::kTuple) {
      // Use the operands of old_root directly, so old_root can be removed
      // later.
      for (int i = 0; i < old_root->operand_count(); ++i) {
        if (i != use_index) {
          operands.push_back(old_root->operands()[i]);
        } else {  // Map the conditional use to the tuple operand.
          b_opd_use.mutable_operands().push_back(old_root->operands()[i]);
        }
      }
    } else if (old_root->shape().IsTuple()) {
      // If old_root is not a kTuple but has a tuple shape, elements within
      // the tuple must be extracted first to be used by the new instructions.
      const Shape& old_shape = old_root->shape();
      for (int64_t i = 0; i < old_shape.tuple_shapes_size(); ++i) {
        auto element =
            computation->AddInstruction(HloInstruction::CreateGetTupleElement(
                old_shape.tuple_shapes(i), old_root, i));
        if (i != use_index) {
          operands.push_back(element);
        } else {
          b_opd_use.mutable_operands().push_back(element);
        }
      }
    } else {
      // If old_root is not a tuple and does not have a tuple shape, use it
      // to replace the conditional directly in the new computation.
      b_opd_use.mutable_operands().push_back(conditional);
    }

    HloInstruction* new_root =
        computation->AddInstruction(HloInstruction::CreateTuple(operands));
    VLOG(2) << "setting new root: " << new_root->ToString() << "\n";
    computation->set_root_instruction(new_root,
                                      /*accept_different_shape*/ true);
    if (old_root->opcode() == HloOpcode::kTuple) {
      TF_RETURN_IF_ERROR(computation->RemoveInstruction(old_root));
    }
    VLOG(2) << "new branch computation: " << computation->ToString() << "\n";
  }
  // Update the get-tuple-element indices of the conditional's users.
  if (use_index != -1) {
    for (auto* user : conditional->users()) {
      if (user->opcode() == HloOpcode::kGetTupleElement &&
          user->tuple_index() > use_index) {
        user->set_tuple_index(user->tuple_index() - 1);
      }
    }
  }
  hoisted_instructions[conditional] = b_old_root;
  int64_t cp_start = 0;
  if (use_index >= 0) {
    VLOG(2) << "Mapping GTE: " << tuple_use->ToString() << "\n";
    hoisted_instructions[tuple_use] = b_opd_use;
  }
  cp_start = (tuple_use != nullptr) ? 1 : 0;
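  // The first boundary in to_move_in is the GTE use of the conditional; it
  // is never copied into the branches, since it is rerouted through the
  // hoisted_instructions mapping set up above.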
  for (int64_t to_move_index = cp_start; to_move_index < to_move_in_size;
       to_move_index++) {
    Boundary b_to_move = to_move_in[to_move_index];
    HloInstruction* op = b_to_move.operands()[0];
    CHECK(op != nullptr);
    bool to_be_used_outside = true;
    VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n";
    if (to_move_index < to_move_in_size - 1 && op->user_count() == 1 &&
        op->users()[0] == to_move_in[to_move_index + 1].operands()[0]) {
      to_be_used_outside = false;
      VLOG(2) << "Instruction is not to be used outside the branch\n";
    }
    Boundary b(Boundary::Position::kInsideBranch);
    for (int i = 0; i < branch_count; i++) {
      auto computation = conditional->branch_computation(i);
      VLOG(2) << "Copying to branch: " << i << "\n";
      TF_RETURN_IF_ERROR(CopyInOrOutOfConditional(b_to_move, i, computation,
                                                  hoisted_instructions));
      VLOG(2) << "Done:" << computation->ToString() << "\n";
      if (to_be_used_outside) {
        auto new_op = hoisted_instructions[op].operands()[i];
        auto new_root = computation->root_instruction();
        new_root->AppendOperand(new_op);
        *new_root->mutable_shape()->add_tuple_shapes() = new_op->shape();
        VLOG(2) << "Extending conditional root " << i << " : "
                << new_root->ToString() << "\n";
      }
      VLOG(2) << "After extending branch root: " << computation->ToString()
              << "\n";
    }
    if (to_be_used_outside) {
      // Modify the uses of the instruction outside of the conditional.
      HloInstruction* gtr = conditional->parent()->AddInstruction(
          HloInstruction::CreateGetTupleElement(op->shape(), conditional,
                                                op_index++));
      TF_RETURN_IF_ERROR(op->ReplaceAllUsesWith(gtr));
      if (conditional->parent()->root_instruction() == op) {
        conditional->parent()->set_root_instruction(gtr);
      }
    }
  }
  VLOG(2) << "Done copying instructions inside branch: "
          << conditional->ToString(HloPrintOptions::Fingerprint()) << "\n";
  // Change the conditional instruction shape to the shape of the new root.
  HloInstruction* new_root =
      conditional->branch_computation(0)->root_instruction();
  *conditional->mutable_shape() = new_root->shape();
  VLOG(2) << "Before removing instructions:"
          << conditional->parent()->ToString() << "\n";
  // Remove the moved-in instructions from the parent computation.
  for (int64_t i = to_move_in_size - 1; i >= 0; i--) {
    Boundary boundary_to_move_in = to_move_in[i];
    HloInstruction* op = boundary_to_move_in.operands()[0];
    if (op->user_count() == 0) {
      VLOG(2) << "Removing boundary:" << boundary_to_move_in.ToString() << "\n";
      TF_RETURN_IF_ERROR(conditional->parent()->RemoveInstruction(op));
      VLOG(2) << "Done removing boundary.\n";
    }
  }

  // Reset the shapes of the user gtes to the new shape.
  if (use_index != -1) {
    for (auto* user : conditional->users()) {
      if (user->opcode() == HloOpcode::kGetTupleElement) {
        VLOG(2) << "Resetting shape of user: " << user->ToString() << "\n";
        *user->mutable_shape() =
            conditional->shape().tuple_shapes(user->tuple_index());
      }
    }
  }
  VLOG(1) << "Done moving instructions inside branches\n"
          << conditional->parent()->ToString(HloPrintOptions::Fingerprint())
          << "\n";
  return true;
}

// Groups single chains of operands or uses of boundaries into new boundaries.
class GroupConnectedBoundaries {
 private:
  std::vector<Boundary> connected_boundaries_, new_boundaries_;
  HloInstruction* conditional_;
  HloComputation* conditional_parent_;
  bool is_layout_sensitive_;
  // Instructions that have been visited but are not going to be moved.
  absl::flat_hash_map<HloInstruction*, int>& visited_count_;
  // The following four lines are configurations of the cost model, which will
  // be used to determine whether to move an instruction (move_config_) and how
  // strongly preferred it is to keep a pair of ops together (reuse_config_).
  // The search_config_ is used to control how to navigate the search space of
  // the cost model in the context of auto/manual tuning. The flipped_ map is
  // used to save which entries in the configuration have been changed in the
  // search/tuning process.
  std::vector<std::vector<int64>>& move_config_;
  std::vector<std::vector<int64>>& reuse_config_;
  std::vector<int64>& search_config_vec_;
  int64* search_config_;
  int64 search_subscript_;
  absl::flat_hash_map<const int64*, int64> flipped_;

  // The FlipMutation function serves to implement the search of alternative
  // cost models by deciding whether to flip a given configuration, saved in
  // the loc parameter. The non_zero parameter provides the new value to use
  // to flip a zero. The msg parameter is only used for debugging purposes.
  int64 FlipMutation(int64* loc, const int64_t non_zero,
                     const std::string& msg) {
    if (*search_config_ == 0 || ContainsKey(flipped_, loc)) {
      VLOG(2) << "Configured not to search or loc is already flipped.";
      return *loc;
    }
    // The last 8 digits control when to start the first flip.
    int c = ConditionalCodeMotion::flip_start(*search_config_);
    VLOG(2) << "flip start index = " << c << "\n";
    // Only flip the decision if c reaches 0.
    if (c > 0) {
      (*search_config_)--;
      return *loc;
    }
    // Digits 8-16 control the maximum number of times to flip a config.
    auto flip_count = ConditionalCodeMotion::DecrementMaxFlip(search_config_);
    VLOG(2) << "max flip count = " << flip_count << "\n";
    VLOG(2) << "Updating max Flipping configuration = " << *search_config_
            << "\n";
    if (flip_count == 0) {
      VLOG(2) << "Maximum flip count has been reached. ";
      if (search_subscript_ + 1 < search_config_vec_.size()) {
        VLOG(2) << "search_subscript_ = " << search_subscript_;
        VLOG(2) << "search config vec size = " << search_config_vec_.size();
        search_config_ = &search_config_vec_[++search_subscript_];
      } else {
        return *loc;
      }
    }
    // Reload digits 16-23 of the configuration, which control how frequently
    // a configuration should be flipped.
    auto flip_stride = ConditionalCodeMotion::flip_stride(*search_config_);
    *search_config_ += flip_stride;
    VLOG(2) << "flip stride = " << flip_stride << "\n";
    VLOG(2) << "Updating Flipping Stride = " << *search_config_ << "\n";

    flipped_[loc] = *loc;
    // Flip the configuration value saved in loc between zero and non_zero.
    switch (*loc) {
      case 0:
        *loc = non_zero;
        break;
      default:
        *loc = 0;
        break;
    }
    VLOG(2) << "Flipping decision for: " << msg << ": from " << flipped_[loc]
            << " to " << *loc << "\n";
    return *loc;
  }

 public:
  explicit GroupConnectedBoundaries(
      HloInstruction* conditional, bool is_layout_sensitive,
      absl::flat_hash_map<HloInstruction*, int>& visited_count,
      std::vector<std::vector<int64>>* move_config,
      std::vector<std::vector<int64>>* reuse_config,
      std::vector<int64>* search_config)
      : conditional_(conditional),
        conditional_parent_(conditional->parent()),
        is_layout_sensitive_(is_layout_sensitive),
        visited_count_(visited_count),
        move_config_(*move_config),
        reuse_config_(*reuse_config),
        search_config_vec_(*search_config),
        search_subscript_(0) {
    VLOG(2) << "Initializing Group Connected Boundaries\n";
    CHECK_NE(search_config, nullptr);
    if (search_config_vec_.empty()) {
      search_config_vec_.push_back(0);
    }
    search_config_ = &search_config_vec_[0];
  }
  // Returns an estimate of the potential reuses carried by a given pair of
  // instructions. Different integers classify different levels of reuse.
  // Assumes all instructions can be fused to enable data reuse.
  int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) {
    std::vector<int64>& curconfig =
        reuse_config_[static_cast<uint32>(op->opcode())];
    // Flip the reuse configuration if tuning the cost model.
    // When flipping, use -10 if flipping to the default reuse model. Other
    // values can be specified if needed to fine-control the decision making.
    int64_t config =
        ((*search_config_) < 0)
            ? FlipMutation(&curconfig[static_cast<uint32>(user->opcode())], -10,
                           HloOpcodeString(op->opcode()) + "->" +
                               HloOpcodeString(user->opcode()))
            : curconfig[static_cast<uint32>(user->opcode())];
    VLOG(2) << "ConditionalCodeMotion: Add reuses carried by instr: "
            << op->ToString() << "=>" << user->ToString() << " : " << config
            << "\n";
    if (config < 0) {
      // Assume that the reuse decreases with increasing user count.
      int count1 = CountNonLeafOps(op->users());
      int count2 = CountNonLeafOps(user->operands());
      return (-config) / count1 / count2;
    }
    return config;
  }
  void clear_recently_visited() {
    for (const auto& boundary : new_boundaries_) {
      visited_count_.erase(boundary.operands()[0]);
    }
  }
  // Returns true if `instruction` is worth hoisting.
  bool WorthHoisting(HloInstruction* instruction, bool is_inside_branch) {
    // This is needed for the "moving-in" transformation, to prevent the root
    // of the parent computation (which contains the conditional) from being
    // moved inside the conditional.
    HloOpcode opcode = instruction->opcode();
    if (opcode == HloOpcode::kTuple &&
        instruction == conditional_parent_->root_instruction()) {
      return false;
    }
    // It is not safe to move collective ops from outside to inside
    // conditional branches, as it may cause synchronization problems when
    // different layouts are assigned to different branches.
    if (DynCast<HloCollectiveInstruction>(instruction) && !is_inside_branch) {
      return false;
    }

    // It is not legal to move the parameter instructions.
    if (opcode == HloOpcode::kParameter) {
      return false;
    }

    // Use the configuration given from outside (e.g., by the autotuner).
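    // move_config_ is indexed by the opcode of the instruction; within a
    // row, the column is selected by the opcode of the first operand (or 0
    // if the row has a single entry). A non-zero entry means the instruction
    // is considered worth moving.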
    std::vector<int64>& curconfig = move_config_[static_cast<uint32>(opcode)];
    auto col = (curconfig.size() == 1) ? 0
               : (instruction->operand_count() > 0)
                   ? static_cast<uint32>(instruction->operand(0)->opcode())
                   : 0;
    VLOG(2) << "column = " << col << "\n";
    VLOG(2) << "config size = " << curconfig.size() << "\n";
    VLOG(2) << "search_config = " << *search_config_ << "\n";
    CHECK(col < curconfig.size());
    uint32 config = ((*search_config_) > 0)
                        ? FlipMutation(&curconfig[col], 1,
                                       "Move-" + HloOpcodeString(opcode))
                        : curconfig[col];
    VLOG(2) << "Checking instruction is worth moving: " << config << "\n";
    VLOG(2) << "after checking search_config = " << *search_config_ << "\n";
    return (config != 0);
  }

  int64 ReusesBeforeBoundary(HloInstruction* user) {
    int64_t reuses = 0;
    for (auto op : user->operands()) {
      // The operand must be an instruction that is not going to be moved (if
      // user is inside the conditional); otherwise it must be the conditional
      // itself and its user must be outside of the conditional.
      if (!ContainsKey(visited_count_, op) && op != conditional_) {
        continue;
      }
      if (auto tuple_gte = DynCast<HloGetTupleElementInstruction>(user)) {
        if (op->opcode() == HloOpcode::kConditional) {
          auto tuple = op->branch_computation(0)->root_instruction();
          if (tuple->opcode() == HloOpcode::kTuple) {
            auto index = tuple_gte->tuple_index();
            CHECK(index < tuple->operand_count());
            op = tuple->mutable_operand(index);
          }
        }
        reuses += ReusesCarriedBy(op, user->users()[0]);
      } else {
        reuses += ReusesCarriedBy(op, user);
      }
    }
    VLOG(2) << "Reuses before instruction " << user->ToString() << ":" << reuses
            << "\n";
    return reuses;
  }

  int64 ReusesAfterBoundary(HloInstruction* user) {
    CHECK(user != nullptr);
    auto all_users = user->users();
    // For now, assume that if an instruction has multiple consumers, it
    // will not be reused, as the reuse may require duplication in
    // fusion and so is expensive. If the situation changes in the future,
    // some aspects of the overall algorithm need to be redesigned to
    // accommodate the change.
    if (all_users.size() > 1) {
      VLOG(2) << "Having multiple users from: " << user->ToString() << "\n";
      return 0;
    }
    if (!all_users.empty()) {
      auto op = all_users[0];
      int64_t reuses = 0;
      // Only count reuses that run through the conditional root.
      if (op == conditional_->branch_computation(0)->root_instruction()) {
        int64_t index = op->operand_index(user);
        for (auto op2 : conditional_->users()) {
          // If the use is not a GetTupleElement, do not consider it for now.
          if (op2->opcode() == HloOpcode::kGetTupleElement) {
            auto tuple_opd = static_cast<HloGetTupleElementInstruction*>(op2);
            if (index == tuple_opd->tuple_index()) {
              all_users = op2->users();
              if (!all_users.empty()) {
                reuses += ReusesCarriedBy(user, all_users[0]);
                break;
              }
            }
          }
        }
      } else if (ContainsKey(visited_count_, op)) {
        reuses += ReusesCarriedBy(user, op);
      }
      VLOG(2) << "reuses after instruction " << user->ToString() << ":"
              << reuses << "\n";
      return reuses;
    }
    return 0;
  }

  int64 BenefitForMovingBoundaries(const std::vector<Boundary>& boundaries,
                                   bool perform_reuse_analysis = true) {
    int64_t reuses_before = 0, reuses_after = 0;
    if (boundaries.size() == 1) {
      if (boundaries[0].IsOutsideBranch() &&
          boundaries[0].operands()[0]->opcode() ==
              HloOpcode::kGetTupleElement) {
        // The only boundary to move in is the get_tuple_element op.
        return -1;
      }
      if (boundaries[0].IsInsideBranch() &&
          boundaries[0].operands()[0]->opcode() == HloOpcode::kTuple) {
        // The only boundary to move out is the tuple op inside branches.
        return -1;
      }
    }
    // If trying alternative moving configurations, turn off reuse analysis.
    if (!perform_reuse_analysis) {
      return 1;
    }
    // For cases like:
    // branch0 {
    //   ROOT copy
    // }
    // branch1 {
    //   ...
    // }
    // cond = conditional(branch0, branch1)
    // copy = copy(cond)
    //
    // We can fold the two copies, thus reducing computation.
    auto get_copy_folding_benefit = [&](HloInstruction* hlo) -> int64 {
      if (hlo->opcode() != HloOpcode::kCopy) {
        return 0;
      }
      const HloGetTupleElementInstruction* gte =
          DynCast<HloGetTupleElementInstruction>(hlo->operand(0));
      if (gte == nullptr) {
        return 0;
      }
      const HloInstruction* conditional = gte->operand(0);
      if (conditional != conditional_) {
        return 0;
      }
      int64_t benefit = 0;
      for (auto* branch : conditional->called_computations()) {
        HloInstruction* root = branch->root_instruction();
        if (root->opcode() == HloOpcode::kTuple) {
          const auto* tuple_operand = root->operand(gte->tuple_index());
          if (tuple_operand->opcode() == HloOpcode::kCopy) {
            if (Shape::Equal()(tuple_operand->operand(0)->shape(),
                               hlo->shape())) {
              benefit += 10;
            }
          }
        }
      }
      return benefit;
    };
    for (const Boundary& b : boundaries) {
      auto op = b.operands()[0];
      if (op == conditional_->branch_computation(0)->root_instruction()) {
        continue;
      }
      VLOG(2) << "Benefit for " << op->ToString();
      reuses_before += ReusesBeforeBoundary(op);
      VLOG(2) << "Reuses before boundary so far: " << reuses_before << "\n";
      reuses_after += ReusesAfterBoundary(op);
      VLOG(2) << "Reuses after boundary so far : " << reuses_after << "\n";
    }

    int64_t copy_folding_benefit = 0;
    if (boundaries[0].IsOutsideBranch()) {
      for (const Boundary& b : boundaries) {
        auto op = b.operands()[0];
        copy_folding_benefit += get_copy_folding_benefit(op);
      }
    }
    VLOG(2) << "Copy folding benefit: " << copy_folding_benefit;

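    // For boundaries inside the branches (candidates to move out), the
    // benefit is reuses_after - reuses_before; for boundaries outside
    // (candidates to move in), it is reuses_before - reuses_after - 1 plus
    // the copy-folding benefit, where the extra -1 biases the decision
    // against moving when the reuse estimates tie.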
    if (reuses_after == 0 && reuses_before == 0 && copy_folding_benefit == 0) {
      return -1;
    } else if (boundaries[0].IsInsideBranch()) {
      return reuses_after - reuses_before;
    } else {
      return reuses_before - reuses_after - 1 + copy_folding_benefit;
    }
  }

  Boundary GetNextBoundary(const Boundary& b, int64_t op_index) {
    Boundary b2(b.GetPosition());
    for (int j = 0; j < b.operands().size(); ++j) {
      HloInstruction* inst = b.operands()[j];
      CHECK(inst != nullptr);
      HloInstruction* op = (b.IsInsideBranch()) ? inst->operands()[op_index]
                                                : inst->users()[op_index];
      CHECK(op != nullptr);
      b2.mutable_operands().push_back(op);
    }
    return b2;
  }

  // Checking whether it is safe to move a boundary when visited through a
  // dependent already considered for moving.
  bool IsSafeToMoveBoundary(const Boundary& next_boundary) {
    int64_t next_boundary_count =
        (next_boundary.IsInsideBranch())
            ? next_boundary.operands()[0]->user_count()
            : CountNonLeafOps(next_boundary.operands()[0]->operands());
    if (next_boundary_count <= 1) {
      // If the boundary has only a single or no dependent, it is safe to
      // move.
      return true;
    } else {
      if (!ContainsKey(visited_count_, next_boundary.operands()[0])) {
        VLOG(2) << "Skip next boundary " << next_boundary.ToString() << "\n"
                << " because it has multiple dependents: "
                << next_boundary_count << "\n";
        visited_count_[next_boundary.operands()[0]] = 1;
        new_boundaries_.push_back(next_boundary);
      } else {
        auto pos = std::find(new_boundaries_.begin(), new_boundaries_.end(),
                             next_boundary);
        if (pos != new_boundaries_.end() ||
            next_boundary.operands().size() == 1) {
          int count = ++visited_count_[next_boundary.operands()[0]];
          if (count == next_boundary_count) {
            VLOG(2) << "Recovering next boundary " << next_boundary.ToString()
                    << "\n"
                    << " because all of its dependents have been visited: "
                    << next_boundary_count << "\n";
            visited_count_.erase(next_boundary.operands()[0]);
            if (pos != new_boundaries_.end()) {
              new_boundaries_.erase(pos);
            }
            return true;
          }
        } else {
          VLOG(2) << "Skip incompatible multi-dependent boundary: "
                  << next_boundary.ToString() << ":" << next_boundary_count
                  << "\n";
        }
      }
    }
    return false;
  }
  // This function is reused both for moving the boundary outside or into a
  // conditional. As a result, readability is somewhat compromised.
  // It might be nice to refactor this function to factor the outside-inside
  // considerations into separate function pointer parameters to improve
  // readability.
  void AddBoundaries(const Boundary& boundary) {
    BoundaryVisitor visitor;
    visitor.AddToWorkList(boundary);
    while (visitor.HasNextBoundary()) {
      Boundary b = visitor.PopNextBoundary();
      VLOG(2) << "visiting boundary " << b.ToString() << "\n";
      if ((b.IsOutsideBranch() || InstructionWithinBranchIdentical(
                                      b.operands(), is_layout_sensitive_)) &&
          IsSafeToMoveBoundary(b) &&
          WorthHoisting(b.operands()[0], b.IsInsideBranch())) {
        connected_boundaries_.push_back(b);
        VLOG(2) << "boundary can be moved\n";
        int64_t operand_count = (b.IsInsideBranch())
                                    ? b.operands()[0]->operand_count()
                                    : b.operands()[0]->users().size();
        for (int i = 0; i < operand_count; i++) {
          Boundary next_boundary = GetNextBoundary(b, i);
          VLOG(2) << "Add operand/user " << i << " to visit later\n";
          visitor.AddToWorkList(next_boundary);
        }
      } else {
        VLOG(2) << "boundary cannot be moved\n";
        visited_count_[b.operands()[0]] = 1;
        new_boundaries_.push_back(b);
      }
    }
  }
  std::vector<Boundary> BoundariesToMoveInOrOut(HloInstruction* conditional,
                                                const Boundary& b) {
    // At the beginning of optimization, a conditional itself is added to a
    // worklist. Here the conditional is expanded into two sets of boundaries:
    // the first set contains the boundary that is inside branches and
    // contains the root of all branches; the second set of boundaries
    // contains all the users of the conditional.
    HloInstruction* inst = b.operands()[0];
    if (inst == conditional) {
      int branch_count = inst->branch_count();
      // Add the conditional roots as a new boundary to visit.
      Boundary boundary_in(Boundary::Position::kInsideBranch);
      for (int i = 0; i < branch_count; i++) {
        HloComputation* branch_computation = inst->branch_computation(i);
        HloInstruction* root_inst = branch_computation->root_instruction();
        CHECK(root_inst != nullptr);
        boundary_in.mutable_operands().push_back(root_inst);
      }
      new_boundaries_.push_back(boundary_in);
      // Add the conditional users as new boundaries to visit.
      for (auto u : inst->users()) {
        Boundary boundary_in(Boundary::Position::kOutsideBranch);
        boundary_in.mutable_operands().push_back(u);
        new_boundaries_.push_back(boundary_in);
      }
    } else {
      AddBoundaries(b);
    }
    return connected_boundaries_;
  }
  void AddNewBoundaries(std::vector<Boundary>& b) {
    b.insert(b.end(), new_boundaries_.begin(), new_boundaries_.end());
  }
};

ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion(
    HloInstruction* conditional, const Boundary& cur_boundary,
    std::vector<Boundary>& to_move, std::vector<Boundary>& new_boundaries,
    absl::flat_hash_map<HloInstruction*, int>& visited_count) {
  GroupConnectedBoundaries connect(conditional, is_layout_sensitive_,
                                   visited_count, &move_config_, &reuse_config_,
                                   &search_config_);
  auto move_in_or_out =
      connect.BoundariesToMoveInOrOut(conditional, cur_boundary);
  if (!move_in_or_out.empty()) {
    auto benefit = connect.BenefitForMovingBoundaries(
        move_in_or_out, search_config_map_.empty());
    VLOG(2) << "benefit of moving in or out "
            << cur_boundary.operands()[0]->ToString() << ":" << benefit << "\n";
    if (benefit >= 0) {
      new_boundaries.clear();
      connect.AddNewBoundaries(new_boundaries);
      // The whole sequence in move_in_or_out is either all moving into a
      // conditional, or all moving out of a conditional. So looking only
      // at the first entry of the sequence is sufficient to know in which
      // direction the move is intended.
      to_move = move_in_or_out;
      return Decision(to_move[0].IsInsideBranch()
                          ? Decision::Direction::kMoveOutOfBranch
                          : Decision::Direction::kMoveIntoBranch,
                      benefit);
    } else {
      connect.clear_recently_visited();
    }
  } else {
    connect.AddNewBoundaries(new_boundaries);
  }
  return Decision(Decision::Direction::kNoChange, 0);
}

StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
  VLOG(2) << "Begin a new pass of conditional code motion optimization.\n";
  // Used to support debugging of the optimization, by disabling it after it
  // has been applied a pre-determined number of times (to isolate the impact
  // of individual transformations).
  if (!ConsumeFuel("conditional_code_motion", [&] {
        return "Skipping conditional opt after allowed limit reaching 0.\n";
      })) {
    return false;
  }
  bool changed = false;
  bool cleanup_changed = false;
  {
    HloPassPipeline subpipeline("before_conditional_code_motion");
    subpipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/is_layout_sensitive_);
    subpipeline.AddPass<HloDCE>();
    TF_ASSIGN_OR_RETURN(auto cleanup_changed_now, subpipeline.Run(module));
    cleanup_changed |= cleanup_changed_now;
  }
  // Gather all the conditional ops in the module ahead of time, to avoid
  // potential complications from modifying the code that would affect the
  // traversal.
  std::vector<HloInstruction*> conditional_ops;
  // Track how many times each branch computation is shared.
  absl::flat_hash_map<HloComputation*, int> conditional_computations;
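  // The first time a branch computation is seen it maps to 0, so a positive
  // count later on means the computation is shared by more than one
  // conditional branch.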
1320 for (auto* comp : module->MakeComputationPostOrder()) {
1321 for (auto* instr : comp->MakeInstructionPostOrder()) {
1322 if (instr->opcode() == HloOpcode::kConditional) {
1323 int branch_count = instr->branch_count();
1324 for (int i = 0; i < branch_count; ++i) {
1325 HloComputation* branch_i = instr->branch_computation(i);
1326 if (ContainsKey(conditional_computations, branch_i)) {
1327 conditional_computations[branch_i]++;
1328 } else {
1329 conditional_computations[branch_i] = 0;
1330 }
1331 }
1332 if (instr->shape().IsTuple()) {
1333 bool can_change_tuple_shape = true;
1334 for (auto user : instr->users()) {
1335 VLOG(2) << "user is : " << user->ToString() << "\n";
1336 if (user->opcode() != HloOpcode::kGetTupleElement) {
1337 can_change_tuple_shape = false;
1338 }
1339 }
1340 if (can_change_tuple_shape) {
1341 conditional_ops.push_back(instr);
1342 }
1343 } else {
1344 conditional_ops.push_back(instr);
1345 }
1346 }
1347 }
1348 }
1349
  int64_t conditional_index = 0;
  // Used to collect the mapping between original and cloned instructions.
  HloCloneContext clone_context(module);
  for (HloInstruction* conditional : conditional_ops) {
    if (conditional_index == 0 || !search_config_map_.empty()) {
      auto config_entry = search_config_map_.find(conditional_index);
      if (config_entry != search_config_map_.end()) {
        search_config_ = (*config_entry).second;
        VLOG(2) << "config entry value extracted:" << search_config_.size();
        search_config_index_ = 0;
      }
      VLOG(2) << "Obtaining default configuration for conditional "
              << conditional_index << "\n";
      SetDefaultMoveConfig();
      VLOG(2) << "Done obtaining default configuration\n";
      conditional_index++;
    }
    int branch_count = conditional->branch_count();
    // Check whether any of this conditional's branch computations is shared,
    // i.e. used more than once in the module.
    bool conditional_is_shared = false;
    for (int i = 0; i < branch_count; ++i) {
      HloComputation* branch_i = conditional->branch_computation(i);
      if (conditional_computations[branch_i] > 0) {
        conditional_is_shared = true;
        break;
      }
    }

    // Boundaries to move out of or into the branches.
    std::vector<std::vector<Boundary>> to_move_out, to_move_in;
    std::vector<std::vector<Boundary>> new_boundaries_for_moveout;
    std::vector<std::vector<Boundary>> new_boundaries_for_movein;
    // Number of times each instruction has been visited for moving.
    absl::flat_hash_map<HloInstruction*, int> visited_count;
    int benefit_move_out = 0, benefit_move_in = 0;
    Decision::Direction final_d = Decision::Direction::kNoChange;
    // The conditional is moved into a worklist as the seed (starting point).
    // The conditional will be expanded into multiple seeds (starting points),
    // its roots and its users, when it is visited by GroupConnectedBoundaries.
    // A NO_CHANGE decision will always be returned for the conditional itself,
    // so that the other seeding boundaries can be visited in turn.
    BoundaryVisitor visitor(conditional);
    VLOG(2) << "Analyzing conditional:" << conditional->ToString() << "\n";
    // Try to visit all the boundaries, collect the analysis results, and save
    // all the beneficial non-conflicting decisions. If two decisions conflict
    // with each other, save the more beneficial one.
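    // Decisions conflict when they pull in opposite directions, and the two
    // running benefit totals below resolve such conflicts per conditional.
    // For example (illustrative numbers only): if one boundary is worth 2
    // when moved out and another is worth 3 when moved in, benefit_move_in
    // wins and the final decision becomes kMoveIntoBranch.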
    while (visitor.HasNextBoundary()) {
      std::vector<Boundary> to_move, next_boundary;
      Boundary boundary = visitor.PopNextBoundary();
      VLOG(2) << "Analyzing boundary:" << boundary.ToString() << "\n";
      auto d = ConsiderCodeMotion(conditional, boundary, to_move, next_boundary,
                                  visited_count);
      switch (d.GetDirection()) {
        case Decision::Direction::kMoveOutOfBranch:
          VLOG(2) << "Local Decision is move out of branch\n";
          to_move_out.push_back(to_move);
          new_boundaries_for_moveout.push_back(next_boundary);
          benefit_move_out += d.GetBenefit();
          if (benefit_move_out >= benefit_move_in) {
            final_d = Decision::Direction::kMoveOutOfBranch;
            VLOG(2) << "Current Decision is move out of branch ("
                    << to_move_out.size() << ")\n";
          } else {
            VLOG(2) << "Current Decision remains move into branch\n";
          }
          break;
        case Decision::Direction::kMoveIntoBranch:
          VLOG(2) << "Decision is move into branch\n";
          to_move_in.push_back(to_move);
          new_boundaries_for_movein.push_back(next_boundary);
          benefit_move_in += d.GetBenefit();
          if (benefit_move_out >= benefit_move_in) {
            VLOG(2) << "Current Decision remains move out of branch\n";
          } else {
            final_d = Decision::Direction::kMoveIntoBranch;
            VLOG(2) << "Current Decision is move into branch ("
                    << to_move_in.size() << ")\n";
          }
          break;
        case Decision::Direction::kNoChange:
          VLOG(2) << "Decision is no change\n";
          for (const Boundary& b : next_boundary) {
            visitor.AddToWorkList(b);
            VLOG(2) << "Adding new boundary to worklist:" << b.ToString()
                    << "\n";
          }
          break;
      }
    }
    // If a modification is to be made, the shared branches need to be cloned
    // first.
    if (final_d != Decision::Direction::kNoChange && conditional_is_shared) {
      for (int i = 0; i < branch_count; ++i) {
        HloComputation* branch_i = conditional->branch_computation(i);
        if (conditional_computations[branch_i] > 0) {
          // Cloning is absolutely needed if the computation is shared by
          // different branches, but the cloning can potentially be avoided
          // if the sharing is only among branches of the same conditional.
          // If cloning these branches causes a problem due to space issues,
          // a fix can pass a vector of unique branches to the actual
          // transformations, as an alternative representation of the
          // conditional branches to be modified. Right now we assume the
          // overhead of cloning is minimal since later stages of the compiler
          // inline all the computations anyway.
          HloComputation* clone_i =
              conditional->parent()->parent()->AddEmbeddedComputation(
                  branch_i->Clone("clone", &clone_context));
          conditional->set_branch_computation(i, clone_i);
          conditional_computations[branch_i]--;
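          // Decrementing the count records that this conditional no longer
          // uses branch_i; once the remaining count reaches zero, the
          // computation is no longer treated as shared.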
          // The analysis results need to be translated to refer to the cloned
          // instructions, so that the transformations below operate on the
          // correct computations.
          auto update_boundary = [&](Boundary& boundary) {
            auto cloned_instr =
                clone_context.FindInstruction(boundary.operands()[i]);
            CHECK(cloned_instr != nullptr);
            VLOG(2) << "boundary before cloning:" << boundary.operands()[i]
                    << "\n";
            boundary.mutable_operands()[i] = cloned_instr;
            VLOG(2) << "boundary after cloning:" << boundary.operands()[i]
                    << "\n";
          };
          // Only boundaries to move out need to be updated. Note that the
          // update_boundary lambda captures the branch index i; the loop
          // variables below are named j to avoid shadowing it.
          if (final_d == Decision::Direction::kMoveOutOfBranch) {
            for (int j = 0; j < to_move_out.size(); ++j) {
              std::vector<Boundary>& m = to_move_out[j];
              std::for_each(m.begin(), m.end(), update_boundary);
            }
            for (int j = 0; j < new_boundaries_for_moveout.size(); ++j) {
              std::vector<Boundary>& m = new_boundaries_for_moveout[j];
              std::for_each(m.begin(), m.end(), update_boundary);
            }
          }
        }
      }
      VLOG(2) << "Cloned branches as needed: " << conditional->ToString()
              << "\n";
    }
    // Boundary groups may have been collected for both directions above, but
    // only the winning direction in final_d is applied, since there is only
    // one optimization decision per conditional.
    if (final_d == Decision::Direction::kMoveOutOfBranch) {
      CHECK(to_move_out.size() == new_boundaries_for_moveout.size());
      for (int i = 0; i < to_move_out.size(); ++i) {
        TF_ASSIGN_OR_RETURN(bool result,
                            MoveInstructionOut(conditional, to_move_out[i],
                                               new_boundaries_for_moveout[i]));
        changed |= result;
      }
      VLOG(2) << "Done moving out of branches " << to_move_out.size()
              << " times.\n";
      if (!ConsumeFuel("conditional_code_motion", [&] {
            return "Skipping conditional opt after allowed limit reached 0.\n";
          })) {
        break;
      }
    } else if (final_d == Decision::Direction::kMoveIntoBranch) {
      CHECK(to_move_in.size() == new_boundaries_for_movein.size());
      for (int i = 0; i < to_move_in.size(); ++i) {
        TF_ASSIGN_OR_RETURN(bool result,
                            MoveInstructionIn(conditional, to_move_in[i],
                                              new_boundaries_for_movein[i]));
        changed |= result;
      }
      VLOG(2) << "Done moving into branches " << to_move_in.size()
              << " times.\n";
      if (!ConsumeFuel("conditional_code_motion", [&] {
            return "Skipping conditional opt after allowed limit reached 0.\n";
          })) {
        break;
      }
    } else if (pursue_full_conditional_code_motion_ && !conditional_is_shared) {
      // Invoke special handling for convert rematerialization/hoisting.
      // We need to make sure no sharing is present in the branches because no
      // cloning has been done by the earlier analysis.
      // TODO(b/165848866): extend the solution to handle cloning for the
      // special move.
      TF_ASSIGN_OR_RETURN(
          bool convert_result,
          ConvertSpecialMove(conditional, is_layout_sensitive_));
      // Record the change before the fuel check, so that breaking out of the
      // loop below cannot drop the fact that the module was modified.
      changed |= convert_result;
      if (convert_result) {
        VLOG(2) << "Done special moving of convert\n";
        if (!ConsumeFuel("conditional_code_motion", [&] {
              return "Skipping conditional opt after allowed limit reached "
                     "0.\n";
            })) {
          break;
        }
      }
    }
  }
  if (changed) {
    HloPassPipeline subpipeline(
        "after_conditional_code_motion_after_convert_hoisting");
    VLOG(2) << "starting after motion passes: DCE\n";
    subpipeline.AddPass<HloDCE>();
    subpipeline.AddPass<TupleSimplifier>();
    subpipeline.AddPass<HloDCE>();
    TF_ASSIGN_OR_RETURN(auto cleanup_changed_now, subpipeline.Run(module));
    cleanup_changed |= cleanup_changed_now;
  }
  if (cleanup_changed) {
    VLOG(2) << "The cleanup subpipelines have modified the code\n";
  }
  return changed;
}

void ConditionalCodeMotion::SetDefaultMoveConfig() {
  VLOG(2) << "search_config_index = " << search_config_index_ << "\n";
  VLOG(2) << "search_config_ size = " << search_config_.size() << "\n";
  int64_t cur_search_config = (search_config_index_ < 0 ||
                               search_config_index_ >= search_config_.size())
                                  ? 0
                                  : search_config_[search_config_index_];
  enum class TuningOption {
    kDoNotTune = 0,
    kTuneTransformationDecision = 1,
    kTuneReuseModel = 2,
  };
  TuningOption tuning_option =
      (cur_search_config == 0) ? TuningOption::kDoNotTune
      : (cur_search_config > 0) ? TuningOption::kTuneTransformationDecision
                                : TuningOption::kTuneReuseModel;
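  // The sign of the current search configuration entry selects what is
  // tuned: zero keeps the defaults, a positive value tunes the transformation
  // decisions, and a negative value tunes the reuse model.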

  auto row = HloOpcodeCount();
  auto col = row;
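  // Both tables built below are square: rows are indexed by the opcode of an
  // instruction and columns by the opcode of its first operand.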
  VLOG(2) << "Start setting default configuration\n";
  reuse_config_.clear();
  move_config_.clear();
  reuse_config_.reserve(row);
  move_config_.reserve(row);
  for (int64_t opcode = 0; opcode < row; ++opcode) {
    // Records how much reuse is carried between this opcode and each possible
    // operand opcode.
    std::vector<int64_t> reuse_vec(col, 0);
    for (uint32_t j = 0; j < col; ++j) {
      reuse_vec[j] = ReusesCarriedBy(static_cast<HloOpcode>(opcode),
                                     static_cast<HloOpcode>(j));
    }
    reuse_config_.push_back(reuse_vec);
    std::vector<int64_t> move_vec;
    switch (tuning_option) {
      case TuningOption::kTuneTransformationDecision:
        // Tuning transformation decisions --- start with all yes.
        // Only a single entry is needed if we don't consider operands of an op
        // when searching/tuning transformation decisions.
        move_vec.push_back(1);
        break;
      case TuningOption::kTuneReuseModel:
        // Tune the ReusesCarriedBy results only; fall through to use the
        // default move configuration.
      case TuningOption::kDoNotTune:
        // No tuning --- use the default configuration.
        // Use the opcode of the first operand to configure the default.
        move_vec.reserve(col);
        for (uint32_t j = 0; j < col; ++j) {
          move_vec.push_back(WorthHoisting(static_cast<HloOpcode>(opcode),
                                           static_cast<HloOpcode>(j)));
        }
        break;
    }
    move_config_.push_back(move_vec);
  }
}

// The search configuration is specified using a string of the form
// 'config1;config2;...;config_n', where each config_i is of the form
// 'index,start,max,stride' (four integers separated by commas). The fields
// specify the index number of the conditional being configured, the index of
// the first transformation decision to flip for the conditional, the maximum
// number of decisions to flip, and how many decisions to skip in between the
// flips.
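// For example (a hypothetical configuration, used here only for
// illustration), "0,1,2,1;1,0,3,2" configures conditional 0 to flip up to 2
// decisions starting at decision 1 with a stride of 1, and conditional 1 to
// flip up to 3 decisions starting at decision 0 with a stride of 2.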
void ConditionalCodeMotion::ParseSearchConfiguration(
    const std::string& search_config) {
  if (search_config.empty()) {
    return;
  }
  search_config_index_ = 0;
  std::vector<std::string> configs = absl::StrSplit(search_config, ';');
  for (const std::string& config : configs) {
    std::vector<std::string> specs = absl::StrSplit(config, ',');
    CHECK_EQ(specs.size(), 4);
    int64_t condition_index;
    CHECK(absl::SimpleAtoi(specs[0], &condition_index));
    auto& cur_config_entry = search_config_map_[condition_index];
    int64_t flip_start, max_flip, flip_stride;
    CHECK(absl::SimpleAtoi(specs[1], &flip_start));
    CHECK(absl::SimpleAtoi(specs[2], &max_flip));
    CHECK(absl::SimpleAtoi(specs[3], &flip_stride));
    // MakeSearchConfig packs the three flip parameters into a single integer
    // entry of the per-conditional configuration list.
    int64_t cur_config = MakeSearchConfig(flip_start, max_flip, flip_stride);
    cur_config_entry.push_back(cur_config);
    VLOG(2) << "Setting search config " << condition_index << "->" << cur_config
            << "\n";
  }
}

}  // namespace conditional_opt

}  // namespace xla