1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
17
18 #include "absl/algorithm/container.h"
19 #include "absl/container/flat_hash_map.h"
20 #include "absl/container/flat_hash_set.h"
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/str_join.h"
23 #include "absl/types/optional.h"
24 #include "tensorflow/compiler/xla/primitive_util.h"
25 #include "tensorflow/compiler/xla/service/call_inliner.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
30 #include "tensorflow/compiler/xla/service/hlo_query.h"
31 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
32 #include "tensorflow/compiler/xla/service/while_loop_analysis.h"
33
34 namespace xla {
35
36 namespace m = match;
37 using absl::optional;
38 using hlo_query::ContainsInstrWithOpcode;
39
40 // This is a utility function that removes the given tuple indices from the
41 // while loop init, body, and condition. The final shape returned is still the
42 // same as before.
RemoveDeadTupleIndices(HloInstruction * while_op,absl::flat_hash_set<int64> & used_tuple_indices)43 static StatusOr<HloInstruction*> RemoveDeadTupleIndices(
44 HloInstruction* while_op, absl::flat_hash_set<int64>& used_tuple_indices) {
45 // Build up maps from the old/new to the new/old tuple indices.
46 std::vector<int64> new_to_old_tuple_idx(used_tuple_indices.begin(),
47 used_tuple_indices.end());
48 absl::c_sort(new_to_old_tuple_idx);
49
50 HloModule* module = while_op->GetModule();
51 HloComputation* computation = while_op->parent();
52 HloInstruction* while_init = while_op->mutable_operand(0);
53 HloComputation* while_cond = while_op->while_condition();
54 HloComputation* while_body = while_op->while_body();
55 HloInstruction* while_body_root = while_body->root_instruction();
56
57 auto print_no_metadata = HloPrintOptions().set_print_metadata(false);
58
59 absl::flat_hash_map<int64, int64> old_to_new_tuple_idx;
60 for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) {
61 int64 old_idx = new_to_old_tuple_idx[new_idx];
62 old_to_new_tuple_idx[old_idx] = new_idx;
63 VLOG(2) << "Remapping tuple index " << old_idx << " to " << new_idx;
64 }
65
66 // Compute the shape of the while op after we remove the dead indices.
67 std::vector<Shape> new_while_tuple_elem_shapes;
68 new_while_tuple_elem_shapes.reserve(new_to_old_tuple_idx.size());
69 for (int64 old_idx : new_to_old_tuple_idx) {
70 new_while_tuple_elem_shapes.push_back(
71 while_init->shape().tuple_shapes(old_idx));
72 }
73 Shape new_while_shape =
74 ShapeUtil::MakeTupleShape(new_while_tuple_elem_shapes);
75
76 // Returns a map from elements in the computation to new instructions which
77 // replace the old instructions after we remove unused elements from the while
78 // tuple.
79 auto make_while_computation_replacements = [&](const HloComputation* comp) {
80 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
81 replacements;
82
83 auto* param = comp->parameter_instruction(0);
84 replacements.emplace(param, HloInstruction::CreateParameter(
85 0, new_while_shape, param->name()));
86
87 // Materialize param's users, since we're about to add new ones below.
88 std::vector<HloInstruction*> materialized_users(param->users().begin(),
89 param->users().end());
90 for (const auto* user : materialized_users) {
91 // The while body root is handled separately.
92 if (user == while_body_root) {
93 continue;
94 }
95 CHECK_EQ(user->opcode(), HloOpcode::kGetTupleElement)
96 << user->ToString(print_no_metadata);
97
98 int64 old_idx = user->tuple_index();
99 auto new_idx_iter = old_to_new_tuple_idx.find(old_idx);
100 if (new_idx_iter != old_to_new_tuple_idx.end()) {
101 // This is a GTE of an index that survives. Replace it.
102 replacements.emplace(
103 user, HloInstruction::CreateGetTupleElement(user->shape(), param,
104 new_idx_iter->second));
105 } else {
106 // This is a GTE of an index that we've removed. Remove it from the
107 // cloned computation.
108 CHECK(user->user_count() == 0 ||
109 user->user_count() == 1 &&
110 user->users().front() == while_body_root)
111 << "Instruction " << user->ToString(print_no_metadata)
112 << " should be unused (except by root of while body), but has "
113 "users: {"
114 << absl::StrJoin(user->users(), ", ",
115 [&](string* out, const HloInstruction* instr) {
116 absl::StrAppend(
117 out, instr->ToString(print_no_metadata));
118 })
119 << "}";
120
121 replacements.emplace(user, nullptr);
122 }
123 }
124 return replacements;
125 };
126
127 // Create the new while condition, body, and init value.
128 std::unique_ptr<HloComputation> new_while_cond =
129 while_cond->CloneWithReplacements(
130 make_while_computation_replacements(while_cond));
131
132 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
133 while_body_replacements = make_while_computation_replacements(while_body);
134 std::vector<HloInstruction*> new_while_body_root_elems;
135 new_while_body_root_elems.reserve(new_to_old_tuple_idx.size());
136 for (int64 old_idx : new_to_old_tuple_idx) {
137 new_while_body_root_elems.push_back(
138 while_body_root->mutable_operand(old_idx));
139 }
140 while_body_replacements.emplace(
141 while_body_root, HloInstruction::CreateTuple(new_while_body_root_elems));
142 std::unique_ptr<HloComputation> new_while_body =
143 while_body->CloneWithReplacements(std::move(while_body_replacements));
144
145 // Add a new while_init instruction that repackages the old while_init
146 // instruction's elements. We rely on the AlgebraicSimplifier and DCE to
147 // clean this up in the common case where while_init is a tuple op. (It's
148 // definitely tuple-shaped, but it's not necessarily a tuple op.)
149 std::vector<HloInstruction*> new_while_init_elems;
150 new_while_init_elems.reserve(new_to_old_tuple_idx.size());
151 for (int64 old_idx : new_to_old_tuple_idx) {
152 new_while_init_elems.push_back(
153 computation->AddInstruction(HloInstruction::CreateGetTupleElement(
154 while_init->shape().tuple_shapes(old_idx), while_init, old_idx)));
155 }
156 auto* new_while_init = computation->AddInstruction(
157 HloInstruction::CreateTuple(new_while_init_elems));
158
159 // Create the new while op.
160 auto* new_while_op = computation->AddInstruction(HloInstruction::CreateWhile(
161 new_while_shape,
162 module->AddEmbeddedComputation(std::move(new_while_cond)),
163 module->AddEmbeddedComputation(std::move(new_while_body)),
164 new_while_init));
165
166 // Create a tuple op that recreates the output of the old while op. That is,
167 // we transform to
168 //
169 // new_while_init while_init
170 // | |
171 // V |
172 // new_while |
173 // | |
174 // -------| |----
175 // V V
176 // new_tuple
177 // |
178 // V
179 // (orig. users of while op)
180 //
181 // The tuple simplifier will then simplify this if possible, removing
182 // new_tuple and while_init.
183 std::vector<HloInstruction*> new_tuple_elems;
184 const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape());
185 for (int64 old_idx = 0; old_idx < tuple_size; ++old_idx) {
186 auto new_tuple_idx_it = old_to_new_tuple_idx.find(old_idx);
187 if (new_tuple_idx_it != old_to_new_tuple_idx.end()) {
188 int64 gte_idx = new_tuple_idx_it->second;
189 new_tuple_elems.push_back(
190 computation->AddInstruction(HloInstruction::CreateGetTupleElement(
191 new_while_op->shape().tuple_shapes(gte_idx), new_while_op,
192 gte_idx)));
193 } else {
194 new_tuple_elems.push_back(
195 computation->AddInstruction(HloInstruction::CreateGetTupleElement(
196 while_init->shape().tuple_shapes(old_idx), while_init, old_idx)));
197 }
198 }
199 HloInstruction* new_tuple =
200 computation->AddInstruction(HloInstruction::CreateTuple(new_tuple_elems));
201 TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, new_tuple));
202
203 return new_while_op;
204 }
205
206 // Tries to remove elements in a while loop's tuple that aren't used within the
207 // loop.
208 //
209 // Specifically, if a loop is tuple-shaped, and there exists some element of
210 // that tuple that is not used by the loop condition and is not used by the loop
211 // body except to pass it to the next iteration of the loop, then we can remove
212 // that element from the loop's tuples.
TryRemoveDeadWhileParams(HloInstruction * while_op)213 static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
214 CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
215
216 // Don't try this transformation if the while loop isn't removable, since if
217 // it succeeds ultimately we're going to have to replace the old while loop
218 // with a new one.
219 if (!while_op->parent()->IsSafelyRemovable(while_op)) {
220 VLOG(2) << "Can't remove dead parameters from non-removable while op.";
221 return false;
222 }
223
224 HloInstruction* while_init = while_op->mutable_operand(0);
225 HloComputation* while_cond = while_op->while_condition();
226 HloComputation* while_body = while_op->while_body();
227 HloInstruction* while_body_root = while_body->root_instruction();
228
229 if (!while_init->shape().IsTuple()) {
230 VLOG(2) << "While op's carried value isn't tuple shaped.";
231 return false;
232 }
233
234 if (while_body_root->opcode() != HloOpcode::kTuple) {
235 VLOG(2) << "While body's root is not a tuple(...) instruction.";
236 return false;
237 }
238
239 auto print_no_metadata = HloPrintOptions().set_print_metadata(false);
240
241 // Bail if param0 of while_cond or while_body has users which aren't of type
242 // get-tuple-element.
243 for (const HloInstruction* instr : {while_body->parameter_instruction(0),
244 while_cond->parameter_instruction(0)}) {
245 for (const HloInstruction* user : instr->users()) {
246 if (user->opcode() != HloOpcode::kGetTupleElement) {
247 VLOG(2) << "Cowardly refusing to analyze while loop with "
248 << instr->ToString(print_no_metadata)
249 << " used by non-GTE instruction "
250 << user->ToString(print_no_metadata) << " in computation "
251 << instr->parent()->name();
252 return false;
253 }
254 }
255 }
256
257 const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape());
258 if (tuple_size == 0) {
259 VLOG(2) << "Can't remove elements from while loop's tuple -- it's already "
260 "empty.";
261 return false;
262 }
263
264 absl::flat_hash_set<int64> used_tuple_indices;
265 for (HloComputation* comp : {while_body, while_cond}) {
266 // The HLO verifier ensures that while_input's shape matches while_init's
267 // shape, which we verified above is a tuple.
268 HloInstruction* while_input = comp->parameter_instruction(0);
269
270 for (const HloInstruction* user : while_input->users()) {
271 // This user doesn't count if it's only used by the while body's root, and
272 // the root places the tuple element into the same index of the tuple as
273 // it came from. That just amounts to us carrying the variable through
274 // the loop.
275 //
276 // Careful: HloInstruction::operand_index returns the first index the
277 // operand appears in, but it may appear more than once!
278 if (user->user_count() == 1 && user->users().front() == while_body_root &&
279 while_body_root->operand_index(user) == user->tuple_index() &&
280 absl::c_count(while_body_root->operands(), user) == 1) {
281 continue;
282 }
283
284 used_tuple_indices.insert(user->tuple_index());
285 if (used_tuple_indices.size() == tuple_size) {
286 VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
287 << " uses all of its inputs; no simplification possible.";
288 return false;
289 }
290 }
291 }
292
293 // If a tuple element is not passed unmodified from the while body's param0
294 // through to the while body's root, count that element as "used", since
295 // removing that element would be observable.
296 for (int64 i = 0; i < while_body_root->operand_count(); ++i) {
297 if (used_tuple_indices.contains(i)) {
298 continue;
299 }
300
301 auto* operand = while_body_root->operand(i);
302 if (operand->opcode() != HloOpcode::kGetTupleElement ||
303 operand->operand(0) != while_body->parameter_instruction(0) ||
304 operand->tuple_index() != i) {
305 VLOG(2) << "Tuple index " << i
306 << " is not passed through loop body unmodified.";
307 used_tuple_indices.insert(i);
308
309 if (used_tuple_indices.size() == tuple_size) {
310 VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
311 << " uses all of its inputs; no simplification possible.";
312 return false;
313 }
314 }
315 }
316
317 // If we got here, used_tuple_indices.size() < tuple_size, meaning some
318 // elements of the loop's tuple aren't used by while_body or while_cond.
319 CHECK_LT(used_tuple_indices.size(), tuple_size);
320
321 VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size()
322 << " elements from tuple of "
323 << while_op->ToString(print_no_metadata);
324
325 TF_ASSIGN_OR_RETURN(while_op,
326 RemoveDeadTupleIndices(while_op, used_tuple_indices));
327
328 return true;
329 }
330
331 // This is a helper function for TryRemoveRepeatedWhileTupleIndices. It removes
332 // duplicates by replacing them with tuple_index, followed by a call to
333 // RemoveDeadTupleIndices.
TryRemoveRepeatedWhileTupleIndicesHelper(HloInstruction * while_op,const int64 tuple_index,absl::flat_hash_set<int64> & duplicates)334 static StatusOr<HloInstruction*> TryRemoveRepeatedWhileTupleIndicesHelper(
335 HloInstruction* while_op, const int64 tuple_index,
336 absl::flat_hash_set<int64>& duplicates) {
337 HloComputation* while_cond = while_op->while_condition();
338 HloComputation* while_body = while_op->while_body();
339 HloInstruction* while_init = while_op->mutable_operand(0);
340
341 VLOG(2) << "while_init " << while_init->ToString() << " operands "
342 << while_init->operand_count();
343 VLOG(2) << "while_body_root " << while_body->root_instruction()->ToString()
344 << " operands " << while_body->root_instruction()->operand_count();
345
346 // Change the loop body and condition such that uses of the duplicates are
347 // replaced with the original tuple element.
348 for (HloComputation* comp : {while_body, while_cond}) {
349 auto new_get = comp->AddInstruction(HloInstruction::CreateGetTupleElement(
350 comp->parameter_instruction(0)->shape().tuple_shapes(tuple_index),
351 comp->parameter_instruction(0), tuple_index));
352
353 std::vector<HloInstruction*> instrs_to_replace;
354 for (auto* instr : comp->instructions()) {
355 if (instr->opcode() == HloOpcode::kGetTupleElement &&
356 duplicates.contains(instr->tuple_index()) &&
357 instr->operand(0) == comp->parameter_instruction(0)) {
358 instrs_to_replace.push_back(instr);
359 }
360 }
361
362 for (auto instr : instrs_to_replace) {
363 TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_get));
364 }
365 }
366
367 // We know which tuple indices are useful; i.e, those which aren't duplicates.
368 absl::flat_hash_set<int64> used_tuple_indices;
369 for (int index = 0; index < while_init->shape().tuple_shapes_size();
370 ++index) {
371 if (!duplicates.count(index)) {
372 used_tuple_indices.insert(index);
373 }
374 }
375
376 // Remove the duplicate tuple elements.
377 TF_ASSIGN_OR_RETURN(while_op,
378 RemoveDeadTupleIndices(while_op, used_tuple_indices));
379
380 return while_op;
381 }
382
383 // If the while loop init passes the same values to several tuple indices, and
384 // if the body keeps on passing them through, we can remove the duplicates.
TryRemoveRepeatedWhileTupleIndices(HloInstruction * while_op)385 static StatusOr<bool> TryRemoveRepeatedWhileTupleIndices(
386 HloInstruction* while_op) {
387 CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
388
389 int index_to_investigate = 0;
390 // Don't try this transformation if the while loop isn't removable, since if
391 // it succeeds ultimately we're going to have to replace the old while loop
392 // with a new one.
393 if (!while_op->parent()->IsSafelyRemovable(while_op)) {
394 VLOG(2) << "Can't remove dead parameters from non-removable while op.";
395 return false;
396 }
397
398 HloInstruction* while_init = while_op->mutable_operand(0);
399 HloComputation* while_cond = while_op->while_condition();
400 HloComputation* while_body = while_op->while_body();
401 HloInstruction* while_body_root = while_body->root_instruction();
402
403 if (!while_init->shape().IsTuple()) {
404 VLOG(2) << "While op's carried value isn't tuple shaped.";
405 return false;
406 }
407
408 bool changed = false;
409 while (index_to_investigate < while_init->shape().tuple_shapes_size()) {
410 if (!while_init->shape().IsTuple() ||
411 while_init->opcode() != HloOpcode::kTuple) {
412 VLOG(2) << "While op's carried value isn't tuple shaped.";
413 return false;
414 }
415
416 if (while_body_root->opcode() != HloOpcode::kTuple) {
417 VLOG(2) << "While body's root is not a tuple(...) instruction.";
418 return false;
419 }
420
421 auto& while_shape = while_init->shape();
422 VLOG(2) << "Iterating " << index_to_investigate;
423
424 absl::flat_hash_set<int64> duplicates;
425 auto* pivot_init_elem = while_init->operand(index_to_investigate);
426 auto* pivot_body_elem = while_body_root->operand(index_to_investigate);
427 if (pivot_body_elem->opcode() == HloOpcode::kGetTupleElement &&
428 pivot_body_elem->operand(0) == while_body->parameter_instruction(0)) {
429 if (pivot_body_elem->tuple_index() != index_to_investigate) {
430 VLOG(2) << "Mismatch between pivot_body_elem->tuple_index() "
431 << pivot_body_elem->tuple_index() << " index_to_investigate "
432 << index_to_investigate;
433 index_to_investigate++;
434 continue;
435 }
436 } else {
437 index_to_investigate++;
438 continue;
439 }
440
441 // Look from index_to_investigate onwards to see if it is repeated.
442 for (int64 i = index_to_investigate + 1;
443 i < while_shape.tuple_shapes_size(); ++i) {
444 auto* init_elem = while_init->operand(i);
445 auto* body_elem = while_body_root->operand(i);
446 if (body_elem->opcode() == HloOpcode::kGetTupleElement &&
447 body_elem->operand(0) == while_body->parameter_instruction(0)) {
448 if (body_elem->tuple_index() != i) {
449 VLOG(2) << "Mismatch between body_elem->tuple_index() "
450 << body_elem->tuple_index() << " i " << i;
451 continue;
452 }
453 } else {
454 continue;
455 }
456
457 if (pivot_init_elem == init_elem) {
458 VLOG(2) << "init_elem " << init_elem->ToString() << " pivot_init_elem "
459 << pivot_init_elem->ToString();
460 VLOG(2) << "body_elem " << body_elem->ToString() << " pivot_body_elem "
461 << pivot_body_elem->ToString();
462 duplicates.insert(i);
463 }
464 }
465
466 // If duplicates are found, call the helper to remove them.
467 if (!duplicates.empty()) {
468 VLOG(2) << "Duplicate found " << duplicates.size() << " pivot_init "
469 << pivot_init_elem->ToString();
470 TF_ASSIGN_OR_RETURN(while_op,
471 TryRemoveRepeatedWhileTupleIndicesHelper(
472 while_op, index_to_investigate, duplicates));
473 changed = true;
474 VLOG(2) << "Changed while_op " << while_op->ToString()
475 << " while_op operand count " << while_op->operand_count();
476 // Update the while loop variables so we can continue looking for
477 // duplicates of a different index.
478 while_init = while_op->mutable_operand(0);
479 while_cond = while_op->while_condition();
480 while_body = while_op->while_body();
481 while_body_root = while_body->root_instruction();
482 }
483 index_to_investigate++;
484 }
485
486 return changed;
487 }
488
489 // Removes each loop parameter (i.e. member of the while loop tuple) that is a
490 // constant and is the same in the while loop body and the while loop init.
TryRemoveConstantParams(HloInstruction * while_op)491 static StatusOr<bool> TryRemoveConstantParams(HloInstruction* while_op) {
492 HloModule* module = while_op->GetModule();
493 HloComputation* computation = while_op->parent();
494 auto* while_init = while_op->mutable_operand(0);
495 auto* while_body = while_op->while_body();
496 auto* while_cond = while_op->while_condition();
497 auto* while_body_root = while_body->root_instruction();
498 if (while_init->opcode() != HloOpcode::kTuple ||
499 while_body_root->opcode() != HloOpcode::kTuple) {
500 return false;
501 }
502
503 TF_RET_CHECK(while_cond->num_parameters() == 1);
504 TF_RET_CHECK(while_body->num_parameters() == 1);
505 TF_RET_CHECK(
506 ShapeUtil::Compatible(while_init->shape(), while_body_root->shape()));
507
508 absl::flat_hash_set<int64> constant_tuple_indices;
509 const auto& while_shape = while_init->shape();
510 for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
511 auto* init_elem = while_init->operand(i);
512 auto* body_elem = while_body_root->operand(i);
513 if (init_elem->opcode() == HloOpcode::kConstant &&
514 body_elem->opcode() == HloOpcode::kConstant &&
515 init_elem->literal() == body_elem->literal()) {
516 constant_tuple_indices.insert(i);
517 }
518 }
519
520 if (constant_tuple_indices.empty()) {
521 return false;
522 }
523
524 // OK, we found some constant elements of the while parameter! Eliminate
525 // them.
526 std::vector<Shape> new_while_shape_elems;
527 for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
528 if (!constant_tuple_indices.count(i)) {
529 new_while_shape_elems.push_back(while_shape.tuple_shapes(i));
530 }
531 }
532 Shape new_while_shape = ShapeUtil::MakeTupleShape(new_while_shape_elems);
533
534 // `new_instrs` holds instructions created outside of a computation for
535 // cloning. Elements added here just need to live until the end of the
536 // relevant CloneWithReplacement call.
537 std::vector<std::unique_ptr<HloInstruction>> new_instrs;
538 auto add_new_instr = [&](std::unique_ptr<HloInstruction> instr) {
539 new_instrs.push_back(std::move(instr));
540 return new_instrs.back().get();
541 };
542
543 // Returns a new tuple without the elements of constant_tuple_indices.
544 auto remove_constant_elems = [&](HloInstruction* instr) {
545 CHECK(ShapeUtil::Compatible(instr->shape(), while_shape));
546
547 std::vector<HloInstruction*> tuple_elems;
548 for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
549 if (!constant_tuple_indices.count(i)) {
550 tuple_elems.push_back(
551 add_new_instr(HloInstruction::CreateGetTupleElement(
552 while_shape.tuple_shapes(i), instr, i)));
553 }
554 }
555 return HloInstruction::CreateTuple(tuple_elems);
556 };
557
558 auto add_constant_elems = [&](HloInstruction* instr) {
559 CHECK(ShapeUtil::Compatible(instr->shape(), new_while_shape));
560
561 std::vector<HloInstruction*> tuple_elems;
562 int64 j = 0;
563 for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
564 if (constant_tuple_indices.count(i)) {
565 tuple_elems.push_back(while_init->mutable_operand(i));
566 } else {
567 tuple_elems.push_back(
568 add_new_instr(HloInstruction::CreateGetTupleElement(
569 while_shape.tuple_shapes(i), instr, j)));
570 ++j;
571 }
572 }
573 return HloInstruction::CreateTuple(tuple_elems);
574 };
575
576 // Special case: constant_tuple_indices covers the whole while parameter, so
577 // the new while shape is the empty tuple. In this case, the value of the
578 // while loop is simply equal to the value of `init`.
579 //
580 // It's unfortunate to special-case this, but it's simpler than the
581 // alternative. The problem is that if our while parameter has no
582 // non-constant elems, the tuple returned by `add_constant_elems` won't depend
583 // on instr (the loop body/cond parameter), and therefore
584 // CloneWithReplacementPairs will *leave the parameter out entirely*, creating
585 // invalid HLO.
586 if (ShapeUtil::IsEmptyTuple(new_while_shape)) {
587 TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, while_init));
588 return true;
589 }
590
591 std::unique_ptr<HloComputation> new_while_cond =
592 while_cond->CloneWithReplacementPairs({
593 while_cond->parameter_instruction(0),
594 add_constant_elems(add_new_instr(HloInstruction::CreateParameter(
595 0, new_while_shape,
596 while_cond->parameter_instruction(0)->name()))),
597 });
598
599 std::unique_ptr<HloComputation> new_while_body =
600 while_body->CloneWithReplacementPairs(
601 {
602 while_body->parameter_instruction(0),
603 add_constant_elems(add_new_instr(HloInstruction::CreateParameter(
604 0, new_while_shape,
605 while_cond->parameter_instruction(0)->name()))),
606 },
607 {
608 while_body->root_instruction(),
609 remove_constant_elems(
610 add_new_instr(while_body->root_instruction()->Clone())),
611 });
612
613 // Create the final while loop, and add any new instructions created to
614 // `computation`.
615 new_instrs.clear();
616 TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
617 while_op,
618 add_constant_elems(
619 computation->AddInstruction(HloInstruction::CreateWhile(
620 new_while_shape,
621 module->AddEmbeddedComputation(std::move(new_while_cond)),
622 module->AddEmbeddedComputation(std::move(new_while_body)),
623 add_new_instr(remove_constant_elems(while_init)))))));
624 for (auto& instr : new_instrs) {
625 computation->AddInstruction(std::move(instr));
626 }
627 return true;
628 }
629
630 // Tries to remove a while loop from the graph.
631 //
632 // - Loops with trip count of 0 can be replaced by the loop's "init" value.
633 // - Loops with trip count of 1 can be replaced by the loop's body, with the
634 // loop itself removed.
635 //
636 // Returns true if it made a change to the graph.
TryRemoveWhileLoop(HloInstruction * while_op)637 static StatusOr<bool> TryRemoveWhileLoop(HloInstruction* while_op) {
638 // Cowardly refuse to remove loops that are not removable. In practice, this
639 // means that we can't remove loops that have control predecessors/successors.
640 if (!while_op->parent()->IsSafelyRemovable(while_op)) {
641 VLOG(2) << "Not attempting to remove while loop that is not removable: "
642 << while_op->ToShortString();
643 return false;
644 }
645
646 // Refuse to remove while loops with a condition that contain side-effects,
647 // because removing a while loop is tantamount to removing its condition.
648 //
649 // TODO(jlebar): This is conservative: We could instead just run the while
650 // condition once (trip-count == 0) or twice (trip-count == 1).
651 if (while_op->while_condition()->HasSideEffect()) {
652 VLOG(2) << "Not attempting to remove while loop whose condition contains "
653 "side-effecting instructions: "
654 << while_op->ToShortString();
655 return false;
656 }
657
658 // Remove while loops with static trip count of 0.
659 optional<int64> trip_count =
660 ComputeWhileLoopTripCount(while_op, /*max_brute_force_iters=*/1);
661 if (trip_count && *trip_count == 0) {
662 // The loop never executes, so the value of the loop is the value of its
663 // "init" operand.
664 auto computation = while_op->parent();
665
666 // Remove while_op (i.e., call ReplaceInstruction rather than
667 // ReplaceUsesWithInstruction) so that if the algebraic simplifier is run in
668 // a loop without an intervening DCE, we don't try to re-remove the loop.
669 TF_RETURN_IF_ERROR(computation->ReplaceInstruction(
670 while_op, while_op->mutable_operand(0)));
671 return true;
672 }
673
674 // Transform while loops with static trip count of 1 into a call op, then
675 // inline the call.
676 if (trip_count && *trip_count == 1) {
677 // Do not simplify the loop away when there is a side-effectful op,
678 // otherwise the infeed op may not inherit the data dependency from
679 // the while loop.
680 //
681 // Example: while_body (param_a) {
682 // param_a = parameter(0)
683 // infeed2 = infeed()
684 // }
685 //
686 // infeed1 = ...
687 // while = while(infeed1), body=while_body // infeed2 has implicit
688 // dependency on infeed1.
689 //
690 // After simplification:
691 //
692 // infeed1 = ...
693 // infeed2 = infeed() // no dependency between infeed1 and infeed2. infeed1
694 // // can be scheduled after infeed2.
695 //
696 bool has_side_effects = absl::c_any_of(
697 while_op->called_computations(), [](const HloComputation* computation) {
698 return computation->HasSideEffect();
699 });
700 if (!has_side_effects) {
701 auto computation = while_op->parent();
702 auto call_op = computation->AddInstruction(HloInstruction::CreateCall(
703 while_op->shape(), while_op->operands(), while_op->while_body()));
704 TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, call_op));
705 TF_ASSIGN_OR_RETURN(auto inlined_instructions_map,
706 CallInliner::Inline(call_op));
707 (void)inlined_instructions_map;
708 return true;
709 } else {
710 VLOG(2) << "Not attempting to simplify while loop because it contains a "
711 "side-effecting node: "
712 << while_op->ToShortString();
713 }
714 }
715 return false;
716 }
717
TryPropagateConstant(HloInstruction * while_op)718 static StatusOr<bool> TryPropagateConstant(HloInstruction* while_op) {
719 auto while_init = while_op->operand(0);
720 if (while_init->opcode() != HloOpcode::kTuple) {
721 return false;
722 }
723
724 auto while_body = while_op->while_body();
725 auto while_body_root = while_body->root_instruction();
726 if (while_body_root->opcode() != HloOpcode::kTuple) {
727 return false;
728 }
729
730 auto while_body_param = while_body->parameter_instruction(0);
731 const HloInstruction::InstructionVector& root_operands =
732 while_body_root->operands();
733
734 // Find the loop invariant tuple elements with scalar constant init value and
735 // build a map from the tuple element index to the constant value. Limit this
736 // to scalar constant values because propagating array constants can regress
737 // performance by forcing us to copy constants.
738 absl::flat_hash_map<int, const HloInstruction*> index_to_constant;
739 for (int i = 0; i < root_operands.size(); i++) {
740 const HloInstruction* init_tuple_elem = nullptr;
741 if (Match(root_operands[i],
742 m::GetTupleElement(m::Op().Is(while_body_param), i)
743 .WithShape(m::Shape().IsScalar())) &&
744 Match(while_init->operand(i), m::Constant(&init_tuple_elem))) {
745 VLOG(3) << "Found loop invariant tuple element " << i << " "
746 << init_tuple_elem->ToString();
747 index_to_constant[i] = init_tuple_elem;
748 }
749 }
750
751 if (index_to_constant.empty()) {
752 return false;
753 }
754
755 // Replace the use of each constant tuple element in the loop_condition and
756 // loop_body with the corresponding constant value.
757 auto propagate_constant = [&](HloComputation* computation) -> StatusOr<bool> {
758 HloInstruction* param = computation->parameter_instruction(0);
759 bool changed = false;
760 for (auto instr : param->users()) {
761 // Since only a while-loop with a tuple result reaches here, we can safely
762 // assume that `param` is a tuple and the first operand of the
763 // GetTupleElement instruction is a use of `param`.
764 if (instr->opcode() == HloOpcode::kGetTupleElement) {
765 VLOG(3) << "tuple index " << instr->tuple_index() << " "
766 << instr->ToString();
767 auto iter = index_to_constant.find(instr->tuple_index());
768 if (iter != index_to_constant.end()) {
769 const HloInstruction* hlo_constant = (*iter).second;
770 VLOG(3) << "Replace use of " << instr->ToString() << " with "
771 << hlo_constant->ToString();
772 TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(
773 computation->AddInstruction(hlo_constant->Clone())));
774 changed = true;
775 }
776 }
777 }
778 return changed;
779 };
780
781 TF_ASSIGN_OR_RETURN(bool changed_cond,
782 propagate_constant(while_op->while_condition()));
783 TF_ASSIGN_OR_RETURN(bool changed_body, propagate_constant(while_body));
784
785 return changed_cond || changed_body;
786 }
787
788 // Converts a flat list of instructions into a tuple of the desired shape. For
789 // example, given a tuple shape ((x, x), x) and instructions {A, B, C}, returns
790 // a tuple of value ((A, B), C).
791 //
792 // desired_shape must be a tuple. (This precondition allows us to return a
793 // unique_ptr rather than a raw ptr.)
UnflattenTupleInstr(absl::Span<HloInstruction * > instrs,const Shape & desired_shape,std::vector<std::unique_ptr<HloInstruction>> * new_instrs)794 static std::unique_ptr<HloInstruction> UnflattenTupleInstr(
795 absl::Span<HloInstruction*> instrs, const Shape& desired_shape,
796 std::vector<std::unique_ptr<HloInstruction>>* new_instrs) {
797 CHECK(desired_shape.IsTuple()) << ShapeUtil::HumanString(desired_shape);
798
799 // For each child shape in `desired_shape`, slice out the correct number of
800 // `instrs` and call UnflattenTupleInstr recursively. At each step we remove
801 // elements from `instrs` so that it only contains instructions we have not
802 // yet processed.
803 std::vector<HloInstruction*> elems;
804 for (int64 i = 0; i < desired_shape.tuple_shapes_size(); ++i) {
805 const Shape& subshape = desired_shape.tuple_shapes(i);
806 if (!subshape.IsTuple()) {
807 elems.push_back(instrs[0]);
808 instrs.remove_prefix(1);
809 continue;
810 }
811
812 // Count the number of leaf nodes underneath desired_shape[i].
813 int64 num_leaves = 0;
814 ShapeUtil::ForEachSubshape(
815 subshape, [&](const Shape& s, const ShapeIndex& /*index*/) {
816 if (!s.IsTuple()) {
817 ++num_leaves;
818 }
819 });
820
821 std::unique_ptr<HloInstruction> subinstr =
822 UnflattenTupleInstr(instrs.subspan(0, num_leaves),
823 desired_shape.tuple_shapes(i), new_instrs);
824 elems.push_back(subinstr.get());
825 new_instrs->push_back(std::move(subinstr));
826 instrs.remove_prefix(num_leaves);
827 }
828 return HloInstruction::CreateTuple(elems);
829 }
830
831 // Builds a vector whose elements are the values in the flattened tuple for
832 // `instr`. For example, if `instr` is a tuple of form ((A, B), C), returns the
833 // vector {A, B, C} (or kGetTupleElement ops which point to A, B, and C).
GetFlatTupleElems(HloInstruction * instr,std::vector<std::unique_ptr<HloInstruction>> * new_instrs)834 static std::vector<HloInstruction*> GetFlatTupleElems(
835 HloInstruction* instr,
836 std::vector<std::unique_ptr<HloInstruction>>* new_instrs) {
837 const auto& shape = instr->shape();
838 if (!shape.IsTuple()) {
839 return {instr};
840 }
841 std::vector<HloInstruction*> elems;
842 for (int64 i = 0; i < shape.tuple_shapes_size(); ++i) {
843 const Shape& subshape = shape.tuple_shapes(i);
844 new_instrs->push_back(
845 HloInstruction::CreateGetTupleElement(subshape, instr, i));
846 auto* gte = new_instrs->back().get();
847 auto flattened_subshape = GetFlatTupleElems(gte, new_instrs);
848 elems.insert(elems.end(), flattened_subshape.begin(),
849 flattened_subshape.end());
850 }
851 return elems;
852 }
853
TryFlattenNestedTuples(HloInstruction * while_op)854 static StatusOr<bool> TryFlattenNestedTuples(HloInstruction* while_op) {
855 HloModule* module = while_op->GetModule();
856 HloComputation* computation = while_op->parent();
857 auto* while_init = while_op->mutable_operand(0);
858 auto* while_body = while_op->while_body();
859 auto* while_cond = while_op->while_condition();
860 auto* while_body_root = while_body->root_instruction();
861 if (while_init->opcode() != HloOpcode::kTuple ||
862 while_body_root->opcode() != HloOpcode::kTuple) {
863 return false;
864 }
865
866 TF_RET_CHECK(while_cond->num_parameters() == 1);
867 TF_RET_CHECK(while_body->num_parameters() == 1);
868 TF_RET_CHECK(
869 ShapeUtil::Compatible(while_init->shape(), while_body_root->shape()));
870 Shape while_shape = while_init->shape();
871 if (!ShapeUtil::IsNestedTuple(while_shape)) {
872 return false;
873 }
874
875 std::vector<Shape> flattened_shape_elems;
876 ShapeUtil::ForEachSubshape(while_shape,
877 [&](const Shape& s, const ShapeIndex& /*index*/) {
878 if (!s.IsTuple()) {
879 flattened_shape_elems.push_back(s);
880 }
881 });
882 Shape flattened_shape = ShapeUtil::MakeTupleShape(flattened_shape_elems);
883
884 // `new_instrs` holds instructions created outside of a computation for
885 // cloning. Elements added here just need to live until the end of the
886 // relevant CloneWithReplacement call.
887 std::vector<std::unique_ptr<HloInstruction>> new_instrs;
888 auto add_new_instr = [&](std::unique_ptr<HloInstruction> instr) {
889 new_instrs.push_back(std::move(instr));
890 return new_instrs.back().get();
891 };
892
893 auto nested = [&](HloInstruction* instr) {
894 std::vector<HloInstruction*> gtes;
895 const Shape& flat_shape = instr->shape();
896 for (int64 i = 0; i < flat_shape.tuple_shapes_size(); ++i) {
897 gtes.push_back(add_new_instr(HloInstruction::CreateGetTupleElement(
898 flat_shape.tuple_shapes(i), instr, i)));
899 }
900 auto nested_instr =
901 UnflattenTupleInstr(absl::MakeSpan(gtes), while_shape, &new_instrs);
902 CHECK(ShapeUtil::Compatible(nested_instr->shape(), while_shape))
903 << ShapeUtil::HumanString(nested_instr->shape()) << " vs "
904 << ShapeUtil::HumanString(while_shape);
905 return nested_instr;
906 };
907
908 auto flattened = [&](HloInstruction* instr) {
909 return HloInstruction::CreateTuple(GetFlatTupleElems(instr, &new_instrs));
910 };
911
912 // Create a new while-condition computation, where parameter 0 has flat shape
913 // but all uses of it go through the nested shape.
914 std::unique_ptr<HloComputation> new_while_cond =
915 while_cond->CloneWithReplacementPairs({
916 while_cond->parameter_instruction(0),
917 nested(add_new_instr(HloInstruction::CreateParameter(
918 0, flattened_shape,
919 while_cond->parameter_instruction(0)->name()))),
920 });
921
922 // Create a new while-body computation, where parameter 0 has a flat shape and
923 // all uses of it go through the nested shape, and where the root has a flat
924 // shape constructed from the old nested root.
925 std::unique_ptr<HloComputation> new_while_body =
926 while_body->CloneWithReplacementPairs(
927 {
928 while_body->parameter_instruction(0),
929 nested(add_new_instr(HloInstruction::CreateParameter(
930 0, flattened_shape,
931 while_body->parameter_instruction(0)->name()))),
932 },
933 {
934 while_body->root_instruction(),
935 flattened(add_new_instr(while_body->root_instruction()->Clone())),
936 });
937
938 // Create the final while loop, and add any new instructions created to
939 // `computation`.
940 new_instrs.clear();
941 TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
942 while_op, nested(computation->AddInstruction(HloInstruction::CreateWhile(
943 flattened_shape,
944 module->AddEmbeddedComputation(std::move(new_while_cond)),
945 module->AddEmbeddedComputation(std::move(new_while_body)),
946 computation->AddInstruction(flattened(while_init)))))));
947 for (auto& instr : new_instrs) {
948 computation->AddInstruction(std::move(instr));
949 }
950 return true;
951 }
952
953 // Tries to merge loop induction variables of a given type.
954 //
955 // In this pass we're only concerned with elements of the loop's tuple that
956 // are effective-scalars of type `elem_ty`. Some terminology:
957 //
958 // - The trip counter is the first element of the loop's tuple that starts at
959 // 0 and does x++ on each iteration.
960 //
961 // - An induction variable is an element of the loop's tuple that is not the
962 // trip counter and does `x += <constant>` on each iteration of the loop.
963 // Negative constants are OK.
964 //
965 // This pass adds a trip counter if one isn't already present, then replaces
966 // each induction variable with
967 //
968 // <initial_value> + <trip_count> * <constant>.
969 //
970 // This reduces the number of scalar operations in the loop, which is important
971 // e.g. on GPUs, where each scalar operation is nontrivially expensive because
972 // it's a separate kernel launch.
973 //
974 // Returns the new loop if a change was made, or null if no change was made.
975 // Note that the new loop is not a valid replacement for the old loop; it may
976 // need to be wrapped in a tuple that changes its shape. We return the loop
977 // itself so that you can call TryMergeInductionVariables in a loop, once for
978 // each integral type elem_ty.
TryMergeInductionVariables(HloInstruction * while_op,PrimitiveType elem_ty)979 static StatusOr<HloInstruction*> TryMergeInductionVariables(
980 HloInstruction* while_op, PrimitiveType elem_ty) {
981 CHECK(primitive_util::IsIntegralType(elem_ty)) << PrimitiveType_Name(elem_ty);
982 HloModule* module = while_op->GetModule();
983 HloComputation* computation = while_op->parent();
984 auto* while_init = while_op->mutable_operand(0);
985 auto* while_body = while_op->while_body();
986 auto* while_cond = while_op->while_condition();
987 auto* while_body_root = while_body->root_instruction();
988 if (while_init->opcode() != HloOpcode::kTuple ||
989 while_body_root->opcode() != HloOpcode::kTuple) {
990 return nullptr;
991 }
992
993 TF_RET_CHECK(while_cond->num_parameters() == 1);
994 TF_RET_CHECK(while_body->num_parameters() == 1);
995 TF_RET_CHECK(
996 ShapeUtil::Compatible(while_init->shape(), while_body_root->shape()));
997 Shape while_shape = while_init->shape();
998
999 // The tuple index of the trip counter, if one is present.
1000 absl::optional<int64> trip_counter;
1001 // Maps the tuple index of each induction variable to its constant increment.
1002 absl::flat_hash_map<int64, const HloConstantInstruction*> induction_vars;
1003 for (int64 i = 0; i < while_body_root->operand_count(); ++i) {
1004 HloInstruction* constant;
1005 if (!Match(while_body_root->mutable_operand(i),
1006 m::AddAnyOrder(m::GetTupleElement(m::Parameter(), i),
1007 m::ConstantScalar(&constant))
1008 .WithShape(m::Shape().WithElementType(elem_ty)))) {
1009 continue;
1010 }
1011 if (!trip_counter && constant->literal().IsAll(1) &&
1012 while_init->operand(i)->IsConstant() &&
1013 while_init->operand(i)->literal().IsAll(0)) {
1014 VLOG(10) << "Found existing trip counter at index " << i;
1015 trip_counter = i;
1016 } else {
1017 VLOG(10) << "Found induction variable at index " << i;
1018 induction_vars.emplace(i, Cast<HloConstantInstruction>(constant));
1019 }
1020 }
1021
1022 // There's only something to simplify if we can either:
1023 //
1024 // - combine one or more induction vars with an existing trip counter, or
1025 // - replace two or more induction variables with a new trip counter.
1026 //
1027 // Put another way, there's only something to simplify if the number of
1028 // induction vars plus the number of existing trip counters (0 or 1) is >= 2.
1029 if (induction_vars.size() + (trip_counter.has_value() ? 1 : 0) < 2) {
1030 return nullptr;
1031 }
1032
1033 // OK, we're going to do the transformation! Set up some helpers.
1034
1035 // `new_instrs` holds instructions created outside of a computation for
1036 // cloning. Elements added here just need to live until the end of the
1037 // relevant CloneWithReplacement call.
1038 std::vector<std::unique_ptr<HloInstruction>> new_instrs;
1039 auto add_new_instr = [&](std::unique_ptr<HloInstruction> instr) {
1040 new_instrs.push_back(std::move(instr));
1041 return new_instrs.back().get();
1042 };
1043
1044 auto add_binary_op = [&](const Shape& shape, HloOpcode opcode,
1045 HloInstruction* lhs, HloInstruction* rhs) {
1046 // Reshape lhs/rhs to the output shape if necessary. This deals with the
1047 // fact that induction variables need only be effective scalars, not true
1048 // scalars.
1049 if (!ShapeUtil::Compatible(shape, lhs->shape())) {
1050 lhs = add_new_instr(HloInstruction::CreateReshape(shape, lhs));
1051 }
1052 if (!ShapeUtil::Compatible(shape, rhs->shape())) {
1053 rhs = add_new_instr(HloInstruction::CreateReshape(shape, rhs));
1054 }
1055 return add_new_instr(HloInstruction::CreateBinary(shape, opcode, lhs, rhs));
1056 };
1057
1058 auto add_gte = [&](HloInstruction* src, int64 idx) {
1059 return add_new_instr(HloInstruction::CreateGetTupleElement(
1060 src->shape().tuple_shapes(idx), src, idx));
1061 };
1062
1063 // Our new while loop will have the same shape as the old while loop, except
1064 // we'll add a trip counter to the end if it wasn't originally present.
1065 Shape new_while_shape = while_shape;
1066 bool added_trip_counter = false;
1067 if (!trip_counter) {
1068 VLOG(10) << "Adding new trip counter to end of loop's tuple.";
1069 trip_counter = new_while_shape.tuple_shapes_size();
1070 *new_while_shape.add_tuple_shapes() =
1071 ShapeUtil::MakeShape(elem_ty, /*dimensions=*/{});
1072 added_trip_counter = true;
1073 }
1074
1075 // Converts `instr` into a tuple of the "old" form -- that is, to a tuple with
1076 // shape `while_body->shape()` and where the induction variables are "reified"
1077 // (i.e. they have value <init> + <counter> * <constant>).
1078 auto convert_to_old_form = [&](HloInstruction* instr) {
1079 CHECK(ShapeUtil::Compatible(instr->shape(), new_while_shape));
1080 std::vector<HloInstruction*> tuple_elems;
1081 for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
1082 const auto& elem_shape = while_shape.tuple_shapes(i);
1083 if (!induction_vars.count(i)) {
1084 tuple_elems.push_back(add_gte(instr, i));
1085 continue;
1086 }
1087 tuple_elems.push_back(add_binary_op(
1088 elem_shape, HloOpcode::kAdd, add_gte(instr, i),
1089 add_binary_op(elem_shape, HloOpcode::kMultiply,
1090 add_gte(instr, *trip_counter),
1091 add_new_instr(induction_vars.at(i)->Clone()))));
1092 }
1093 return HloInstruction::CreateTuple(tuple_elems);
1094 };
1095
1096 // Converts `root` into a tuple of the "new" form -- that is, to a tuple with
1097 // shape `new_while_shape` and where the induction variables (but not trip
1098 // counters) are replaced with their unchanging <loop_body_param> values.
1099 auto convert_to_new_form = [&](HloInstruction* old_root,
1100 HloParameterInstruction* loop_body_param) {
1101 CHECK(ShapeUtil::Compatible(old_root->shape(), while_shape));
1102 std::vector<HloInstruction*> tuple_elems;
1103
1104 // In the new form, induction variables come from `init`, everything else
1105 // (including the trip counter if it's not one we created ourselves) comes
1106 // from the `root` tuple unmodified.
1107 for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
1108 tuple_elems.push_back(
1109 add_gte((induction_vars.count(i) ? loop_body_param : old_root), i));
1110 }
1111 // If we created a trip counter ourselves, add 1 to it in the next
1112 // iteration.
1113 if (added_trip_counter) {
1114 tuple_elems.push_back(add_binary_op(
1115 new_while_shape.tuple_shapes(*trip_counter), HloOpcode::kAdd,
1116 add_gte(loop_body_param, *trip_counter),
1117 add_new_instr(
1118 HloInstruction::CreateConstant(LiteralUtil::One(elem_ty)))));
1119 }
1120
1121 return HloInstruction::CreateTuple(tuple_elems);
1122 };
1123
1124 // Creates a new init tuple, which is the same as the old init tuple except if
1125 // we added a trip counter, it's set to 0.
1126 auto get_new_while_init = [&](HloInstruction* init) {
1127 CHECK(ShapeUtil::Compatible(init->shape(), while_shape));
1128 if (!added_trip_counter) {
1129 return init;
1130 }
1131 std::vector<HloInstruction*> tuple_elems;
1132 for (int64 i = 0; i < while_shape.tuple_shapes_size(); ++i) {
1133 tuple_elems.push_back(add_gte(init, i));
1134 }
1135 tuple_elems.push_back(add_new_instr(
1136 HloInstruction::CreateConstant(LiteralUtil::Zero(elem_ty))));
1137 return add_new_instr(HloInstruction::CreateTuple(tuple_elems));
1138 };
1139
1140 std::unique_ptr<HloComputation> new_while_cond =
1141 while_cond->CloneWithReplacementPairs({
1142 while_cond->parameter_instruction(0),
1143 convert_to_old_form(add_new_instr(HloInstruction::CreateParameter(
1144 0, new_while_shape,
1145 while_cond->parameter_instruction(0)->name()))),
1146 });
1147
1148 // Creating the new while body proceeds in two steps. First we convert the
1149 // users of the parameter to the old form. Then as a second
1150 // CloneWithReplacement operation we convert the root to the new form. We
1151 // have to do this in two steps because the new root needs to use the new
1152 // param0, and during the first clone operation, only the *old-form* param0 is
1153 // accessible.
1154 //
1155 // We have to add temp_new_while_body to the module because cloning a
1156 // computation touches the module (to get its NameUniquer).
1157 HloComputation* temp_new_while_body =
1158 module->AddEmbeddedComputation(while_body->CloneWithReplacementPairs({
1159 while_body->parameter_instruction(0),
1160 convert_to_old_form(add_new_instr(HloInstruction::CreateParameter(
1161 0, new_while_shape,
1162 while_body->parameter_instruction(0)->name()))),
1163 }));
1164 std::unique_ptr<HloComputation> new_while_body =
1165 temp_new_while_body->CloneWithReplacementPairs({
1166 temp_new_while_body->root_instruction(),
1167 convert_to_new_form(
1168 add_new_instr(temp_new_while_body->root_instruction()->Clone()),
1169 Cast<HloParameterInstruction>(
1170 temp_new_while_body->parameter_instruction(0))),
1171 });
1172 TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(temp_new_while_body));
1173
1174 // Create the final while loop, and add any new instructions created to
1175 // `computation`.
1176 new_instrs.clear();
1177 auto* new_while = computation->AddInstruction(HloInstruction::CreateWhile(
1178 new_while_shape,
1179 module->AddEmbeddedComputation(std::move(new_while_cond)),
1180 module->AddEmbeddedComputation(std::move(new_while_body)),
1181 get_new_while_init(while_init)));
1182 TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
1183 while_op, convert_to_old_form(new_while)));
1184 for (auto& instr : new_instrs) {
1185 computation->AddInstruction(std::move(instr));
1186 }
1187 return new_while;
1188 }
1189
Run(HloModule * module)1190 StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
1191 XLA_VLOG_LINES(3,
1192 "WhileLoopSimplifier::Run(), before:\n" + module->ToString());
1193 bool changed = false;
1194
1195 // Gather all the while ops in our module. We do this ahead of time so we
1196 // don't have to worry about mutating the lists of computations or
1197 // instructions while we iterate.
1198 std::vector<HloInstruction*> while_ops;
1199 for (auto* comp : module->computations()) {
1200 for (auto* instr : comp->instructions()) {
1201 if (instr->opcode() == HloOpcode::kWhile) {
1202 while_ops.push_back(instr);
1203 }
1204 }
1205 }
1206
1207 for (HloInstruction* while_op : while_ops) {
1208 // We can't remove while loops that contain send/recv nodes, because we rely
1209 // on the particular loop structure around the node matching on the send and
1210 // recv sides. Other while simplifications require us to remove the loop
1211 // and replace it with a new one, so we can't do that either.
1212 if (ContainsInstrWithOpcode(while_op->while_body(),
1213 {HloOpcode::kSend, HloOpcode::kSendDone,
1214 HloOpcode::kRecv, HloOpcode::kRecvDone}) ||
1215 ContainsInstrWithOpcode(while_op->while_condition(),
1216 {HloOpcode::kSend, HloOpcode::kSendDone,
1217 HloOpcode::kRecv, HloOpcode::kRecvDone})) {
1218 VLOG(2) << "Not attempting to simplify while loop because it contains a "
1219 "send/recv node: "
1220 << while_op->ToShortString();
1221 continue;
1222 }
1223
1224 TF_ASSIGN_OR_RETURN(bool result, TryPropagateConstant(while_op));
1225 changed |= result;
1226
1227 TF_ASSIGN_OR_RETURN(result, TryRemoveWhileLoop(while_op));
1228 changed |= result;
1229
1230 if (result) {
1231 // Don't continue simplifying after successfully removing the while loop
1232 // -- that would result in use-after-free nastiness.
1233 continue;
1234 }
1235
1236 // TODO(b/119281462): Cowardly refuse to perform any of the following
1237 // optimizations in the presence of kDomain instructions. It seems that
1238 // modifying a while loop's tuple doesn't work when kDomain is present.
1239 if (ContainsInstrWithOpcode(while_op->while_body(), {HloOpcode::kDomain}) ||
1240 ContainsInstrWithOpcode(while_op->while_condition(),
1241 {HloOpcode::kDomain})) {
1242 continue;
1243 }
1244
1245 // Each of the optimizations below modifies the while loop itself if it's
1246 // successful, meaning that `while_op` is no longer valid after one of these
1247 // transformations returns true.
1248
1249 TF_ASSIGN_OR_RETURN(result, TryRemoveRepeatedWhileTupleIndices(while_op));
1250 changed |= result;
1251 if (result) {
1252 continue;
1253 }
1254
1255 TF_ASSIGN_OR_RETURN(result, TryFlattenNestedTuples(while_op));
1256 changed |= result;
1257 if (result) {
1258 continue;
1259 }
1260
1261 TF_ASSIGN_OR_RETURN(result, TryRemoveDeadWhileParams(while_op));
1262
1263 changed |= result;
1264 if (result) {
1265 continue;
1266 }
1267
1268 TF_ASSIGN_OR_RETURN(result, TryRemoveConstantParams(while_op));
1269 changed |= result;
1270 if (result) {
1271 continue;
1272 }
1273
1274 bool merged_induction_vars = false;
1275 // Notably missing from this list are S16 and U16. These don't currently
1276 // work because S/U16 literals are not implemented.
1277 for (auto elem_ty : {S8, U8, S32, U32, S64, U64}) {
1278 TF_ASSIGN_OR_RETURN(auto* new_while_op,
1279 TryMergeInductionVariables(while_op, elem_ty));
1280 if (new_while_op) {
1281 while_op = new_while_op;
1282 changed = true;
1283 merged_induction_vars = true;
1284 }
1285 }
1286 if (merged_induction_vars) {
1287 continue;
1288 }
1289 }
1290
1291 XLA_VLOG_LINES(3,
1292 "WhileLoopSimplifier::Run(), after:\n" + module->ToString());
1293 return changed;
1294 }
1295
1296 } // namespace xla
1297