1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
17
18 #include "absl/algorithm/container.h"
19 #include "absl/container/flat_hash_map.h"
20 #include "absl/container/flat_hash_set.h"
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/str_join.h"
23 #include "absl/types/optional.h"
24 #include "tensorflow/compiler/xla/primitive_util.h"
25 #include "tensorflow/compiler/xla/service/call_inliner.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
30 #include "tensorflow/compiler/xla/service/hlo_query.h"
31 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
32 #include "tensorflow/compiler/xla/service/while_loop_analysis.h"
33
34 namespace xla {
35
36 namespace m = match;
37 using absl::optional;
38 using hlo_query::ContainsInstrWithOpcode;
39
40 // A helper function that copy the raw JSON-encoded backend config from the old
41 // while op to the new one.
CopyBackendConfig(HloInstruction * old_while_op,HloInstruction * new_while_op)42 void CopyBackendConfig(HloInstruction* old_while_op,
43 HloInstruction* new_while_op) {
44 new_while_op->set_raw_backend_config_string(
45 old_while_op->raw_backend_config_string());
46 }
47
48 // A helper function that copy the frontend attributes from the old while op to
49 // the new one.
CopyFrontendAttributes(HloInstruction * old_while_op,HloInstruction * new_while_op)50 void CopyFrontendAttributes(HloInstruction* old_while_op,
51 HloInstruction* new_while_op) {
52 new_while_op->add_frontend_attributes(old_while_op->frontend_attributes());
53 }
54
55 // This is a utility function that removes the given tuple indices from the
56 // while loop init, body, and condition. The final shape returned is still the
57 // same as before.
RemoveDeadTupleIndices(HloInstruction * while_op,absl::flat_hash_set<int64> & used_tuple_indices)58 static StatusOr<HloInstruction*> RemoveDeadTupleIndices(
59 HloInstruction* while_op, absl::flat_hash_set<int64>& used_tuple_indices) {
60 // Build up maps from the old/new to the new/old tuple indices.
61 std::vector<int64> new_to_old_tuple_idx(used_tuple_indices.begin(),
62 used_tuple_indices.end());
63 absl::c_sort(new_to_old_tuple_idx);
64
65 HloModule* module = while_op->GetModule();
66 HloComputation* computation = while_op->parent();
67 HloInstruction* while_init = while_op->mutable_operand(0);
68 HloComputation* while_cond = while_op->while_condition();
69 HloComputation* while_body = while_op->while_body();
70 HloInstruction* while_body_root = while_body->root_instruction();
71
72 auto print_no_metadata = HloPrintOptions().set_print_metadata(false);
73
74 absl::flat_hash_map<int64, int64> old_to_new_tuple_idx;
75 for (int64_t new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) {
76 int64_t old_idx = new_to_old_tuple_idx[new_idx];
77 old_to_new_tuple_idx[old_idx] = new_idx;
78 VLOG(2) << "Remapping tuple index " << old_idx << " to " << new_idx;
79 }
80
81 // Compute the shape of the while op after we remove the dead indices.
82 std::vector<Shape> new_while_tuple_elem_shapes;
83 new_while_tuple_elem_shapes.reserve(new_to_old_tuple_idx.size());
84 for (int64_t old_idx : new_to_old_tuple_idx) {
85 new_while_tuple_elem_shapes.push_back(
86 while_init->shape().tuple_shapes(old_idx));
87 }
88 Shape new_while_shape =
89 ShapeUtil::MakeTupleShape(new_while_tuple_elem_shapes);
90
91 // Returns a map from elements in the computation to new instructions which
92 // replace the old instructions after we remove unused elements from the while
93 // tuple.
94 auto make_while_computation_replacements = [&](const HloComputation* comp) {
95 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
96 replacements;
97
98 auto* param = comp->parameter_instruction(0);
99 replacements.emplace(param, HloInstruction::CreateParameter(
100 0, new_while_shape, param->name()));
101
102 // Materialize param's users, since we're about to add new ones below.
103 std::vector<HloInstruction*> materialized_users(param->users().begin(),
104 param->users().end());
105 for (const auto* user : materialized_users) {
106 // The while body root is handled separately.
107 if (user == while_body_root) {
108 continue;
109 }
110 CHECK_EQ(user->opcode(), HloOpcode::kGetTupleElement)
111 << user->ToString(print_no_metadata);
112
113 int64_t old_idx = user->tuple_index();
114 auto new_idx_iter = old_to_new_tuple_idx.find(old_idx);
115 if (new_idx_iter != old_to_new_tuple_idx.end()) {
116 // This is a GTE of an index that survives. Replace it.
117 replacements.emplace(
118 user, HloInstruction::CreateGetTupleElement(user->shape(), param,
119 new_idx_iter->second));
120 } else {
121 // This is a GTE of an index that we've removed. Remove it from the
122 // cloned computation.
123 CHECK(user->user_count() == 0 ||
124 user->user_count() == 1 &&
125 user->users().front() == while_body_root)
126 << "Instruction " << user->ToString(print_no_metadata)
127 << " should be unused (except by root of while body), but has "
128 "users: {"
129 << absl::StrJoin(user->users(), ", ",
130 [&](string* out, const HloInstruction* instr) {
131 absl::StrAppend(
132 out, instr->ToString(print_no_metadata));
133 })
134 << "}";
135
136 replacements.emplace(user, nullptr);
137 }
138 }
139 return replacements;
140 };
141
142 // Create the new while condition, body, and init value.
143 std::unique_ptr<HloComputation> new_while_cond =
144 while_cond->CloneWithReplacements(
145 make_while_computation_replacements(while_cond));
146
147 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
148 while_body_replacements = make_while_computation_replacements(while_body);
149 std::vector<HloInstruction*> new_while_body_root_elems;
150 new_while_body_root_elems.reserve(new_to_old_tuple_idx.size());
151 for (int64_t old_idx : new_to_old_tuple_idx) {
152 new_while_body_root_elems.push_back(
153 while_body_root->mutable_operand(old_idx));
154 }
155 while_body_replacements.emplace(
156 while_body_root, HloInstruction::CreateTuple(new_while_body_root_elems));
157 std::unique_ptr<HloComputation> new_while_body =
158 while_body->CloneWithReplacements(std::move(while_body_replacements));
159
160 // Add a new while_init instruction that repackages the old while_init
161 // instruction's elements. We rely on the AlgebraicSimplifier and DCE to
162 // clean this up in the common case where while_init is a tuple op. (It's
163 // definitely tuple-shaped, but it's not necessarily a tuple op.)
164 std::vector<HloInstruction*> new_while_init_elems;
165 new_while_init_elems.reserve(new_to_old_tuple_idx.size());
166 for (int64_t old_idx : new_to_old_tuple_idx) {
167 new_while_init_elems.push_back(
168 computation->AddInstruction(HloInstruction::CreateGetTupleElement(
169 while_init->shape().tuple_shapes(old_idx), while_init, old_idx)));
170 }
171 auto* new_while_init = computation->AddInstruction(
172 HloInstruction::CreateTuple(new_while_init_elems));
173
174 // Create the new while op.
175 auto* new_while_op = computation->AddInstruction(HloInstruction::CreateWhile(
176 new_while_shape,
177 module->AddEmbeddedComputation(std::move(new_while_cond)),
178 module->AddEmbeddedComputation(std::move(new_while_body)),
179 new_while_init));
180 CopyBackendConfig(while_op, new_while_op);
181 CopyFrontendAttributes(while_op, new_while_op);
182
183 // Create a tuple op that recreates the output of the old while op. That is,
184 // we transform to
185 //
186 // new_while_init while_init
187 // | |
188 // V |
189 // new_while |
190 // | |
191 // -------| |----
192 // V V
193 // new_tuple
194 // |
195 // V
196 // (orig. users of while op)
197 //
198 // The tuple simplifier will then simplify this if possible, removing
199 // new_tuple and while_init.
200 std::vector<HloInstruction*> new_tuple_elems;
201 const int64_t tuple_size = ShapeUtil::TupleElementCount(while_init->shape());
202 for (int64_t old_idx = 0; old_idx < tuple_size; ++old_idx) {
203 auto new_tuple_idx_it = old_to_new_tuple_idx.find(old_idx);
204 if (new_tuple_idx_it != old_to_new_tuple_idx.end()) {
205 int64_t gte_idx = new_tuple_idx_it->second;
206 new_tuple_elems.push_back(
207 computation->AddInstruction(HloInstruction::CreateGetTupleElement(
208 new_while_op->shape().tuple_shapes(gte_idx), new_while_op,
209 gte_idx)));
210 } else {
211 new_tuple_elems.push_back(
212 computation->AddInstruction(HloInstruction::CreateGetTupleElement(
213 while_init->shape().tuple_shapes(old_idx), while_init, old_idx)));
214 }
215 }
216 HloInstruction* new_tuple =
217 computation->AddInstruction(HloInstruction::CreateTuple(new_tuple_elems));
218 TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, new_tuple));
219
220 return new_while_op;
221 }
222
223 // Tries to remove elements in a while loop's tuple that aren't used within the
224 // loop.
225 //
226 // Specifically, if a loop is tuple-shaped, and there exists some element of
227 // that tuple that is not used by the loop condition and is not used by the loop
228 // body except to pass it to the next iteration of the loop, then we can remove
229 // that element from the loop's tuples.
TryRemoveDeadWhileParams(HloInstruction * while_op)230 static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
231 CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
232
233 // Don't try this transformation if the while loop isn't removable, since if
234 // it succeeds ultimately we're going to have to replace the old while loop
235 // with a new one.
236 if (!while_op->parent()->IsSafelyRemovable(while_op)) {
237 VLOG(2) << "Can't remove dead parameters from non-removable while op.";
238 return false;
239 }
240
241 HloInstruction* while_init = while_op->mutable_operand(0);
242 HloComputation* while_cond = while_op->while_condition();
243 HloComputation* while_body = while_op->while_body();
244 HloInstruction* while_body_root = while_body->root_instruction();
245
246 if (!while_init->shape().IsTuple()) {
247 VLOG(2) << "While op's carried value isn't tuple shaped.";
248 return false;
249 }
250
251 if (while_body_root->opcode() != HloOpcode::kTuple) {
252 VLOG(2) << "While body's root is not a tuple(...) instruction.";
253 return false;
254 }
255
256 auto print_no_metadata = HloPrintOptions().set_print_metadata(false);
257
258 // Bail if param0 of while_cond or while_body has users which aren't of type
259 // get-tuple-element.
260 for (const HloInstruction* instr : {while_body->parameter_instruction(0),
261 while_cond->parameter_instruction(0)}) {
262 for (const HloInstruction* user : instr->users()) {
263 if (user->opcode() != HloOpcode::kGetTupleElement) {
264 VLOG(2) << "Cowardly refusing to analyze while loop with "
265 << instr->ToString(print_no_metadata)
266 << " used by non-GTE instruction "
267 << user->ToString(print_no_metadata) << " in computation "
268 << instr->parent()->name();
269 return false;
270 }
271 }
272 }
273
274 const int64_t tuple_size = ShapeUtil::TupleElementCount(while_init->shape());
275 if (tuple_size == 0) {
276 VLOG(2) << "Can't remove elements from while loop's tuple -- it's already "
277 "empty.";
278 return false;
279 }
280
281 absl::flat_hash_set<int64> used_tuple_indices;
282 for (HloComputation* comp : {while_body, while_cond}) {
283 // The HLO verifier ensures that while_input's shape matches while_init's
284 // shape, which we verified above is a tuple.
285 HloInstruction* while_input = comp->parameter_instruction(0);
286
287 for (const HloInstruction* user : while_input->users()) {
288 // This user doesn't count if it's only used by the while body's root, and
289 // the root places the tuple element into the same index of the tuple as
290 // it came from. That just amounts to us carrying the variable through
291 // the loop.
292 //
293 // Careful: HloInstruction::operand_index returns the first index the
294 // operand appears in, but it may appear more than once!
295 if (user->user_count() == 1 && user->users().front() == while_body_root &&
296 while_body_root->operand_index(user) == user->tuple_index() &&
297 absl::c_count(while_body_root->operands(), user) == 1) {
298 continue;
299 }
300
301 used_tuple_indices.insert(user->tuple_index());
302 if (used_tuple_indices.size() == tuple_size) {
303 VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
304 << " uses all of its inputs; no simplification possible.";
305 return false;
306 }
307 }
308 }
309
310 absl::flat_hash_set<int64> used_indices_after_loop;
311 if (while_op == while_op->parent()->root_instruction()) {
312 for (int64_t i = 0; i < while_body_root->operand_count(); ++i) {
313 used_indices_after_loop.insert(i);
314 }
315 }
316 for (auto user : while_op->users()) {
317 if (user->opcode() != HloOpcode::kGetTupleElement) {
318 for (int64_t i = 0; i < while_body_root->operand_count(); ++i) {
319 used_indices_after_loop.insert(i);
320 }
321 break;
322 }
323 used_indices_after_loop.insert(user->tuple_index());
324 }
325 // If a tuple element is used after the loop but not passed unmodified from
326 // the while body's param0 through to the while body's root, count that
327 // element as "used", since removing that element would be observable.
328 for (int64_t i = 0; i < while_body_root->operand_count(); ++i) {
329 if (used_tuple_indices.contains(i) ||
330 !used_indices_after_loop.contains(i)) {
331 continue;
332 }
333
334 auto* operand = while_body_root->operand(i);
335 if (operand->opcode() != HloOpcode::kGetTupleElement ||
336 operand->operand(0) != while_body->parameter_instruction(0) ||
337 operand->tuple_index() != i) {
338 VLOG(2) << "Tuple index " << i
339 << " is not passed through loop body unmodified.";
340 used_tuple_indices.insert(i);
341
342 if (used_tuple_indices.size() == tuple_size) {
343 VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
344 << " uses all of its inputs; no simplification possible.";
345 return false;
346 }
347 }
348 }
349
350 // If we got here, used_tuple_indices.size() < tuple_size, meaning some
351 // elements of the loop's tuple aren't used by while_body or while_cond.
352 CHECK_LT(used_tuple_indices.size(), tuple_size);
353
354 VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size()
355 << " elements from tuple of "
356 << while_op->ToString(print_no_metadata);
357
358 TF_ASSIGN_OR_RETURN(while_op,
359 RemoveDeadTupleIndices(while_op, used_tuple_indices));
360
361 return true;
362 }
363
364 // This is a helper function for TryRemoveRepeatedWhileTupleIndices. It removes
365 // duplicates by replacing them with tuple_index, followed by a call to
366 // RemoveDeadTupleIndices.
TryRemoveRepeatedWhileTupleIndicesHelper(HloInstruction * while_op,const int64_t tuple_index,absl::flat_hash_set<int64> & duplicates)367 static StatusOr<HloInstruction*> TryRemoveRepeatedWhileTupleIndicesHelper(
368 HloInstruction* while_op, const int64_t tuple_index,
369 absl::flat_hash_set<int64>& duplicates) {
370 HloComputation* while_cond = while_op->while_condition();
371 HloComputation* while_body = while_op->while_body();
372 HloInstruction* while_init = while_op->mutable_operand(0);
373
374 VLOG(2) << "while_init " << while_init->ToString() << " operands "
375 << while_init->operand_count();
376 VLOG(2) << "while_body_root " << while_body->root_instruction()->ToString()
377 << " operands " << while_body->root_instruction()->operand_count();
378
379 // Change the loop body and condition such that uses of the duplicates are
380 // replaced with the original tuple element.
381 for (HloComputation* comp : {while_body, while_cond}) {
382 auto new_get = comp->AddInstruction(HloInstruction::CreateGetTupleElement(
383 comp->parameter_instruction(0)->shape().tuple_shapes(tuple_index),
384 comp->parameter_instruction(0), tuple_index));
385
386 std::vector<HloInstruction*> instrs_to_replace;
387 for (auto* instr : comp->instructions()) {
388 if (instr->opcode() == HloOpcode::kGetTupleElement &&
389 duplicates.contains(instr->tuple_index()) &&
390 instr->operand(0) == comp->parameter_instruction(0)) {
391 instrs_to_replace.push_back(instr);
392 }
393 }
394
395 for (auto instr : instrs_to_replace) {
396 TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_get));
397 }
398 }
399
400 // We know which tuple indices are useful; i.e, those which aren't duplicates.
401 absl::flat_hash_set<int64> used_tuple_indices;
402 for (int index = 0; index < while_init->shape().tuple_shapes_size();
403 ++index) {
404 if (!duplicates.count(index)) {
405 used_tuple_indices.insert(index);
406 }
407 }
408
409 // Remove the duplicate tuple elements.
410 TF_ASSIGN_OR_RETURN(while_op,
411 RemoveDeadTupleIndices(while_op, used_tuple_indices));
412
413 return while_op;
414 }
415
416 // If the while loop init passes the same values to several tuple indices, and
417 // if the body keeps on passing them through, we can remove the duplicates.
TryRemoveRepeatedWhileTupleIndices(HloInstruction * while_op)418 static StatusOr<bool> TryRemoveRepeatedWhileTupleIndices(
419 HloInstruction* while_op) {
420 CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
421
422 int index_to_investigate = 0;
423 // Don't try this transformation if the while loop isn't removable, since if
424 // it succeeds ultimately we're going to have to replace the old while loop
425 // with a new one.
426 if (!while_op->parent()->IsSafelyRemovable(while_op)) {
427 VLOG(2) << "Can't remove dead parameters from non-removable while op.";
428 return false;
429 }
430
431 HloInstruction* while_init = while_op->mutable_operand(0);
432 HloComputation* while_cond = while_op->while_condition();
433 HloComputation* while_body = while_op->while_body();
434 HloInstruction* while_body_root = while_body->root_instruction();
435
436 if (!while_init->shape().IsTuple()) {
437 VLOG(2) << "While op's carried value isn't tuple shaped.";
438 return false;
439 }
440
441 bool changed = false;
442 while (index_to_investigate < while_init->shape().tuple_shapes_size()) {
443 if (!while_init->shape().IsTuple() ||
444 while_init->opcode() != HloOpcode::kTuple) {
445 VLOG(2) << "While op's carried value isn't tuple shaped.";
446 return false;
447 }
448
449 if (while_body_root->opcode() != HloOpcode::kTuple) {
450 VLOG(2) << "While body's root is not a tuple(...) instruction.";
451 return false;
452 }
453
454 auto& while_shape = while_init->shape();
455 VLOG(2) << "Iterating " << index_to_investigate;
456
457 absl::flat_hash_set<int64> duplicates;
458 auto* pivot_init_elem = while_init->operand(index_to_investigate);
459 auto* pivot_body_elem = while_body_root->operand(index_to_investigate);
460 if (pivot_body_elem->opcode() == HloOpcode::kGetTupleElement &&
461 pivot_body_elem->operand(0) == while_body->parameter_instruction(0)) {
462 if (pivot_body_elem->tuple_index() != index_to_investigate) {
463 VLOG(2) << "Mismatch between pivot_body_elem->tuple_index() "
464 << pivot_body_elem->tuple_index() << " index_to_investigate "
465 << index_to_investigate;
466 index_to_investigate++;
467 continue;
468 }
469 } else {
470 index_to_investigate++;
471 continue;
472 }
473
474 // Look from index_to_investigate onwards to see if it is repeated.
475 for (int64_t i = index_to_investigate + 1;
476 i < while_shape.tuple_shapes_size(); ++i) {
477 auto* init_elem = while_init->operand(i);
478 auto* body_elem = while_body_root->operand(i);
479 if (body_elem->opcode() == HloOpcode::kGetTupleElement &&
480 body_elem->operand(0) == while_body->parameter_instruction(0)) {
481 if (body_elem->tuple_index() != i) {
482 VLOG(2) << "Mismatch between body_elem->tuple_index() "
483 << body_elem->tuple_index() << " i " << i;
484 continue;
485 }
486 } else {
487 continue;
488 }
489
490 if (pivot_init_elem == init_elem) {
491 VLOG(2) << "init_elem " << init_elem->ToString() << " pivot_init_elem "
492 << pivot_init_elem->ToString();
493 VLOG(2) << "body_elem " << body_elem->ToString() << " pivot_body_elem "
494 << pivot_body_elem->ToString();
495 duplicates.insert(i);
496 }
497 }
498
499 // If duplicates are found, call the helper to remove them.
500 if (!duplicates.empty()) {
501 VLOG(2) << "Duplicate found " << duplicates.size() << " pivot_init "
502 << pivot_init_elem->ToString();
503 TF_ASSIGN_OR_RETURN(while_op,
504 TryRemoveRepeatedWhileTupleIndicesHelper(
505 while_op, index_to_investigate, duplicates));
506 changed = true;
507 VLOG(2) << "Changed while_op " << while_op->ToString()
508 << " while_op operand count " << while_op->operand_count();
509 // Update the while loop variables so we can continue looking for
510 // duplicates of a different index.
511 while_init = while_op->mutable_operand(0);
512 while_cond = while_op->while_condition();
513 while_body = while_op->while_body();
514 while_body_root = while_body->root_instruction();
515 }
516 index_to_investigate++;
517 }
518
519 return changed;
520 }
521
522 // Removes each loop parameter (i.e. member of the while loop tuple) that is a
523 // constant and is the same in the while loop body and the while loop init.
TryRemoveConstantParams(HloInstruction * while_op)524 static StatusOr<bool> TryRemoveConstantParams(HloInstruction* while_op) {
525 HloModule* module = while_op->GetModule();
526 HloComputation* computation = while_op->parent();
527 auto* while_init = while_op->mutable_operand(0);
528 auto* while_body = while_op->while_body();
529 auto* while_cond = while_op->while_condition();
530 auto* while_body_root = while_body->root_instruction();
531 if (while_init->opcode() != HloOpcode::kTuple ||
532 while_body_root->opcode() != HloOpcode::kTuple) {
533 return false;
534 }
535
536 TF_RET_CHECK(while_cond->num_parameters() == 1);
537 TF_RET_CHECK(while_body->num_parameters() == 1);
538 TF_RET_CHECK(
539 ShapeUtil::Compatible(while_init->shape(), while_body_root->shape()));
540
541 absl::flat_hash_set<int64> constant_tuple_indices;
542 const auto& while_shape = while_init->shape();
543 for (int64_t i = 0; i < while_shape.tuple_shapes_size(); ++i) {
544 auto* init_elem = while_init->operand(i);
545 auto* body_elem = while_body_root->operand(i);
546 if (init_elem->opcode() == HloOpcode::kConstant &&
547 body_elem->opcode() == HloOpcode::kConstant &&
548 init_elem->literal() == body_elem->literal()) {
549 constant_tuple_indices.insert(i);
550 }
551 }
552
553 if (constant_tuple_indices.empty()) {
554 return false;
555 }
556
557 // OK, we found some constant elements of the while parameter! Eliminate
558 // them.
559 std::vector<Shape> new_while_shape_elems;
560 for (int64_t i = 0; i < while_shape.tuple_shapes_size(); ++i) {
561 if (!constant_tuple_indices.count(i)) {
562 new_while_shape_elems.push_back(while_shape.tuple_shapes(i));
563 }
564 }
565 Shape new_while_shape = ShapeUtil::MakeTupleShape(new_while_shape_elems);
566
567 // `new_instrs` holds instructions created outside of a computation for
568 // cloning. Elements added here just need to live until the end of the
569 // relevant CloneWithReplacement call.
570 std::vector<std::unique_ptr<HloInstruction>> new_instrs;
571 auto add_new_instr = [&](std::unique_ptr<HloInstruction> instr) {
572 new_instrs.push_back(std::move(instr));
573 return new_instrs.back().get();
574 };
575
576 // Returns a new tuple without the elements of constant_tuple_indices.
577 auto remove_constant_elems = [&](HloInstruction* instr) {
578 CHECK(ShapeUtil::Compatible(instr->shape(), while_shape));
579
580 std::vector<HloInstruction*> tuple_elems;
581 for (int64_t i = 0; i < while_shape.tuple_shapes_size(); ++i) {
582 if (!constant_tuple_indices.count(i)) {
583 tuple_elems.push_back(
584 add_new_instr(HloInstruction::CreateGetTupleElement(
585 while_shape.tuple_shapes(i), instr, i)));
586 }
587 }
588 return HloInstruction::CreateTuple(tuple_elems);
589 };
590
591 auto add_constant_elems = [&](HloInstruction* instr) {
592 CHECK(ShapeUtil::Compatible(instr->shape(), new_while_shape));
593
594 std::vector<HloInstruction*> tuple_elems;
595 int64_t j = 0;
596 for (int64_t i = 0; i < while_shape.tuple_shapes_size(); ++i) {
597 if (constant_tuple_indices.count(i)) {
598 tuple_elems.push_back(while_init->mutable_operand(i));
599 } else {
600 tuple_elems.push_back(
601 add_new_instr(HloInstruction::CreateGetTupleElement(
602 while_shape.tuple_shapes(i), instr, j)));
603 ++j;
604 }
605 }
606 return HloInstruction::CreateTuple(tuple_elems);
607 };
608
609 // Special case: constant_tuple_indices covers the whole while parameter, so
610 // the new while shape is the empty tuple. In this case, the value of the
611 // while loop is simply equal to the value of `init`.
612 //
613 // It's unfortunate to special-case this, but it's simpler than the
614 // alternative. The problem is that if our while parameter has no
615 // non-constant elems, the tuple returned by `add_constant_elems` won't depend
616 // on instr (the loop body/cond parameter), and therefore
617 // CloneWithReplacementPairs will *leave the parameter out entirely*, creating
618 // invalid HLO.
619 if (ShapeUtil::IsEmptyTuple(new_while_shape)) {
620 TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, while_init));
621 return true;
622 }
623
624 std::unique_ptr<HloComputation> new_while_cond =
625 while_cond->CloneWithReplacementPairs({
626 while_cond->parameter_instruction(0),
627 add_constant_elems(add_new_instr(HloInstruction::CreateParameter(
628 0, new_while_shape,
629 while_cond->parameter_instruction(0)->name()))),
630 });
631
632 std::unique_ptr<HloComputation> new_while_body =
633 while_body->CloneWithReplacementPairs(
634 {
635 while_body->parameter_instruction(0),
636 add_constant_elems(add_new_instr(HloInstruction::CreateParameter(
637 0, new_while_shape,
638 while_cond->parameter_instruction(0)->name()))),
639 },
640 {
641 while_body->root_instruction(),
642 remove_constant_elems(
643 add_new_instr(while_body->root_instruction()->Clone())),
644 });
645
646 // Create the final while loop, and add any new instructions created to
647 // `computation`.
648 new_instrs.clear();
649 auto* new_while_op = computation->AddInstruction(HloInstruction::CreateWhile(
650 new_while_shape,
651 module->AddEmbeddedComputation(std::move(new_while_cond)),
652 module->AddEmbeddedComputation(std::move(new_while_body)),
653 add_new_instr(remove_constant_elems(while_init))));
654 CopyBackendConfig(while_op, new_while_op);
655 CopyFrontendAttributes(while_op, new_while_op);
656 TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
657 while_op, add_constant_elems(new_while_op)));
658 for (auto& instr : new_instrs) {
659 computation->AddInstruction(std::move(instr));
660 }
661 return true;
662 }
663
664 // Tries to remove a while loop from the graph.
665 //
666 // - Loops with trip count of 0 can be replaced by the loop's "init" value.
667 // - Loops with trip count of 1 can be replaced by the loop's body, with the
668 // loop itself removed.
669 //
670 // Returns true if it made a change to the graph.
TryRemoveWhileLoop(HloInstruction * while_op)671 static StatusOr<bool> TryRemoveWhileLoop(HloInstruction* while_op) {
672 // Cowardly refuse to remove loops that are not removable. In practice, this
673 // means that we can't remove loops that have control predecessors/successors.
674 if (!while_op->parent()->IsSafelyRemovable(while_op)) {
675 VLOG(2) << "Not attempting to remove while loop that is not removable: "
676 << while_op->ToShortString();
677 return false;
678 }
679
680 // Refuse to remove while loops with a condition that contain side-effects,
681 // because removing a while loop is tantamount to removing its condition.
682 //
683 // TODO(jlebar): This is conservative: We could instead just run the while
684 // condition once (trip-count == 0) or twice (trip-count == 1).
685 if (while_op->while_condition()->HasSideEffect()) {
686 VLOG(2) << "Not attempting to remove while loop whose condition contains "
687 "side-effecting instructions: "
688 << while_op->ToShortString();
689 return false;
690 }
691
692 // Remove while loops with static trip count of 0.
693 optional<int64> trip_count =
694 ComputeWhileLoopTripCount(while_op, /*max_brute_force_iters=*/1);
695 if (trip_count && *trip_count == 0) {
696 // The loop never executes, so the value of the loop is the value of its
697 // "init" operand.
698 auto computation = while_op->parent();
699
700 // Remove while_op (i.e., call ReplaceInstruction rather than
701 // ReplaceUsesWithInstruction) so that if the algebraic simplifier is run in
702 // a loop without an intervening DCE, we don't try to re-remove the loop.
703 TF_RETURN_IF_ERROR(computation->ReplaceInstruction(
704 while_op, while_op->mutable_operand(0)));
705 return true;
706 }
707
708 // Transform while loops with static trip count of 1 into a call op, then
709 // inline the call.
710 const auto& attrs = while_op->frontend_attributes().map();
711 bool skip_trip_count_one_simplification =
712 attrs.contains("skip-simplify-while-loops/trip-count-one") &&
713 (attrs.at("skip-simplify-while-loops/trip-count-one") == "true");
714 if (trip_count && *trip_count == 1 && !skip_trip_count_one_simplification) {
715 // Do not simplify the loop away when there is a side-effectful op,
716 // otherwise the infeed op may not inherit the data dependency from
717 // the while loop.
718 //
719 // Example: while_body (param_a) {
720 // param_a = parameter(0)
721 // infeed2 = infeed()
722 // }
723 //
724 // infeed1 = ...
725 // while = while(infeed1), body=while_body // infeed2 has implicit
726 // dependency on infeed1.
727 //
728 // After simplification:
729 //
730 // infeed1 = ...
731 // infeed2 = infeed() // no dependency between infeed1 and infeed2. infeed1
732 // // can be scheduled after infeed2.
733 //
734 bool has_side_effects = absl::c_any_of(
735 while_op->called_computations(), [](const HloComputation* computation) {
736 return computation->HasSideEffect();
737 });
738 if (!has_side_effects) {
739 auto computation = while_op->parent();
740 auto call_op = computation->AddInstruction(HloInstruction::CreateCall(
741 while_op->shape(), while_op->operands(), while_op->while_body()));
742 TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, call_op));
743 TF_ASSIGN_OR_RETURN(auto inlined_instructions_map,
744 CallInliner::Inline(call_op));
745 (void)inlined_instructions_map;
746 return true;
747 } else {
748 VLOG(2) << "Not attempting to simplify while loop because it contains a "
749 "side-effecting node: "
750 << while_op->ToShortString();
751 }
752 }
753 return false;
754 }
755
TryPropagateConstant(HloInstruction * while_op)756 static StatusOr<bool> TryPropagateConstant(HloInstruction* while_op) {
757 auto while_init = while_op->operand(0);
758 if (while_init->opcode() != HloOpcode::kTuple) {
759 return false;
760 }
761
762 auto while_body = while_op->while_body();
763 auto while_body_root = while_body->root_instruction();
764 if (while_body_root->opcode() != HloOpcode::kTuple) {
765 return false;
766 }
767
768 auto while_body_param = while_body->parameter_instruction(0);
769 const HloInstruction::InstructionVector& root_operands =
770 while_body_root->operands();
771
772 // Find the loop invariant tuple elements with scalar constant init value and
773 // build a map from the tuple element index to the constant value. Limit this
774 // to scalar constant values because propagating array constants can regress
775 // performance by forcing us to copy constants.
776 absl::flat_hash_map<int, const HloInstruction*> index_to_constant;
777 for (int i = 0; i < root_operands.size(); i++) {
778 const HloInstruction* init_tuple_elem = nullptr;
779 if (Match(root_operands[i],
780 m::GetTupleElement(m::Op().Is(while_body_param), i)
781 .WithShape(m::Shape().IsScalar())) &&
782 Match(while_init->operand(i), m::Constant(&init_tuple_elem))) {
783 VLOG(3) << "Found loop invariant tuple element " << i << " "
784 << init_tuple_elem->ToString();
785 index_to_constant[i] = init_tuple_elem;
786 }
787 }
788
789 if (index_to_constant.empty()) {
790 return false;
791 }
792
793 // Replace the use of each constant tuple element in the loop_condition and
794 // loop_body with the corresponding constant value.
795 auto propagate_constant = [&](HloComputation* computation) -> StatusOr<bool> {
796 HloInstruction* param = computation->parameter_instruction(0);
797 bool changed = false;
798 for (auto instr : param->users()) {
799 // Since only a while-loop with a tuple result reaches here, we can safely
800 // assume that `param` is a tuple and the first operand of the
801 // GetTupleElement instruction is a use of `param`.
802 if (instr->opcode() == HloOpcode::kGetTupleElement) {
803 VLOG(3) << "tuple index " << instr->tuple_index() << " "
804 << instr->ToString();
805 auto iter = index_to_constant.find(instr->tuple_index());
806 if (iter != index_to_constant.end()) {
807 const HloInstruction* hlo_constant = (*iter).second;
808 VLOG(3) << "Replace use of " << instr->ToString() << " with "
809 << hlo_constant->ToString();
810 TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(
811 computation->AddInstruction(hlo_constant->Clone())));
812 changed = true;
813 }
814 }
815 }
816 return changed;
817 };
818
819 TF_ASSIGN_OR_RETURN(bool changed_cond,
820 propagate_constant(while_op->while_condition()));
821 TF_ASSIGN_OR_RETURN(bool changed_body, propagate_constant(while_body));
822
823 return changed_cond || changed_body;
824 }
825
826 // Converts a flat list of instructions into a tuple of the desired shape. For
827 // example, given a tuple shape ((x, x), x) and instructions {A, B, C}, returns
828 // a tuple of value ((A, B), C).
829 //
830 // desired_shape must be a tuple. (This precondition allows us to return a
831 // unique_ptr rather than a raw ptr.)
UnflattenTupleInstr(absl::Span<HloInstruction * > instrs,const Shape & desired_shape,std::vector<std::unique_ptr<HloInstruction>> * new_instrs)832 static std::unique_ptr<HloInstruction> UnflattenTupleInstr(
833 absl::Span<HloInstruction*> instrs, const Shape& desired_shape,
834 std::vector<std::unique_ptr<HloInstruction>>* new_instrs) {
835 CHECK(desired_shape.IsTuple()) << ShapeUtil::HumanString(desired_shape);
836
837 // For each child shape in `desired_shape`, slice out the correct number of
838 // `instrs` and call UnflattenTupleInstr recursively. At each step we remove
839 // elements from `instrs` so that it only contains instructions we have not
840 // yet processed.
841 std::vector<HloInstruction*> elems;
842 for (int64_t i = 0; i < desired_shape.tuple_shapes_size(); ++i) {
843 const Shape& subshape = desired_shape.tuple_shapes(i);
844 if (!subshape.IsTuple()) {
845 elems.push_back(instrs[0]);
846 instrs.remove_prefix(1);
847 continue;
848 }
849
850 // Count the number of leaf nodes underneath desired_shape[i].
851 int64_t num_leaves = 0;
852 ShapeUtil::ForEachSubshape(
853 subshape, [&](const Shape& s, const ShapeIndex& /*index*/) {
854 if (!s.IsTuple()) {
855 ++num_leaves;
856 }
857 });
858
859 std::unique_ptr<HloInstruction> subinstr =
860 UnflattenTupleInstr(instrs.subspan(0, num_leaves),
861 desired_shape.tuple_shapes(i), new_instrs);
862 elems.push_back(subinstr.get());
863 new_instrs->push_back(std::move(subinstr));
864 instrs.remove_prefix(num_leaves);
865 }
866 return HloInstruction::CreateTuple(elems);
867 }
868
869 // Builds a vector whose elements are the values in the flattened tuple for
870 // `instr`. For example, if `instr` is a tuple of form ((A, B), C), returns the
871 // vector {A, B, C} (or kGetTupleElement ops which point to A, B, and C).
GetFlatTupleElems(HloInstruction * instr,std::vector<std::unique_ptr<HloInstruction>> * new_instrs)872 static std::vector<HloInstruction*> GetFlatTupleElems(
873 HloInstruction* instr,
874 std::vector<std::unique_ptr<HloInstruction>>* new_instrs) {
875 const auto& shape = instr->shape();
876 if (!shape.IsTuple()) {
877 return {instr};
878 }
879 std::vector<HloInstruction*> elems;
880 for (int64_t i = 0; i < shape.tuple_shapes_size(); ++i) {
881 const Shape& subshape = shape.tuple_shapes(i);
882 new_instrs->push_back(
883 HloInstruction::CreateGetTupleElement(subshape, instr, i));
884 auto* gte = new_instrs->back().get();
885 auto flattened_subshape = GetFlatTupleElems(gte, new_instrs);
886 elems.insert(elems.end(), flattened_subshape.begin(),
887 flattened_subshape.end());
888 }
889 return elems;
890 }
891
TryFlattenNestedTuples(HloInstruction * while_op)892 static StatusOr<bool> TryFlattenNestedTuples(HloInstruction* while_op) {
893 HloModule* module = while_op->GetModule();
894 HloComputation* computation = while_op->parent();
895 auto* while_init = while_op->mutable_operand(0);
896 auto* while_body = while_op->while_body();
897 auto* while_cond = while_op->while_condition();
898 auto* while_body_root = while_body->root_instruction();
899 if (while_init->opcode() != HloOpcode::kTuple ||
900 while_body_root->opcode() != HloOpcode::kTuple) {
901 return false;
902 }
903
904 TF_RET_CHECK(while_cond->num_parameters() == 1);
905 TF_RET_CHECK(while_body->num_parameters() == 1);
906 TF_RET_CHECK(
907 ShapeUtil::Compatible(while_init->shape(), while_body_root->shape()));
908 Shape while_shape = while_init->shape();
909 if (!ShapeUtil::IsNestedTuple(while_shape)) {
910 return false;
911 }
912
913 std::vector<Shape> flattened_shape_elems;
914 ShapeUtil::ForEachSubshape(while_shape,
915 [&](const Shape& s, const ShapeIndex& /*index*/) {
916 if (!s.IsTuple()) {
917 flattened_shape_elems.push_back(s);
918 }
919 });
920 Shape flattened_shape = ShapeUtil::MakeTupleShape(flattened_shape_elems);
921
922 // `new_instrs` holds instructions created outside of a computation for
923 // cloning. Elements added here just need to live until the end of the
924 // relevant CloneWithReplacement call.
925 std::vector<std::unique_ptr<HloInstruction>> new_instrs;
926 auto add_new_instr = [&](std::unique_ptr<HloInstruction> instr) {
927 new_instrs.push_back(std::move(instr));
928 return new_instrs.back().get();
929 };
930
931 auto nested = [&](HloInstruction* instr) {
932 std::vector<HloInstruction*> gtes;
933 const Shape& flat_shape = instr->shape();
934 for (int64_t i = 0; i < flat_shape.tuple_shapes_size(); ++i) {
935 gtes.push_back(add_new_instr(HloInstruction::CreateGetTupleElement(
936 flat_shape.tuple_shapes(i), instr, i)));
937 }
938 auto nested_instr =
939 UnflattenTupleInstr(absl::MakeSpan(gtes), while_shape, &new_instrs);
940 CHECK(ShapeUtil::Compatible(nested_instr->shape(), while_shape))
941 << ShapeUtil::HumanString(nested_instr->shape()) << " vs "
942 << ShapeUtil::HumanString(while_shape);
943 return nested_instr;
944 };
945
946 auto flattened = [&](HloInstruction* instr) {
947 return HloInstruction::CreateTuple(GetFlatTupleElems(instr, &new_instrs));
948 };
949
950 // Create a new while-condition computation, where parameter 0 has flat shape
951 // but all uses of it go through the nested shape.
952 std::unique_ptr<HloComputation> new_while_cond =
953 while_cond->CloneWithReplacementPairs({
954 while_cond->parameter_instruction(0),
955 nested(add_new_instr(HloInstruction::CreateParameter(
956 0, flattened_shape,
957 while_cond->parameter_instruction(0)->name()))),
958 });
959
960 // Create a new while-body computation, where parameter 0 has a flat shape and
961 // all uses of it go through the nested shape, and where the root has a flat
962 // shape constructed from the old nested root.
963 std::unique_ptr<HloComputation> new_while_body =
964 while_body->CloneWithReplacementPairs(
965 {
966 while_body->parameter_instruction(0),
967 nested(add_new_instr(HloInstruction::CreateParameter(
968 0, flattened_shape,
969 while_body->parameter_instruction(0)->name()))),
970 },
971 {
972 while_body->root_instruction(),
973 flattened(add_new_instr(while_body->root_instruction()->Clone())),
974 });
975
976 // Create the final while loop, and add any new instructions created to
977 // `computation`.
978 new_instrs.clear();
979 auto* new_while_op = computation->AddInstruction(HloInstruction::CreateWhile(
980 flattened_shape,
981 module->AddEmbeddedComputation(std::move(new_while_cond)),
982 module->AddEmbeddedComputation(std::move(new_while_body)),
983 computation->AddInstruction(flattened(while_init))));
984 CopyBackendConfig(while_op, new_while_op);
985 CopyFrontendAttributes(while_op, new_while_op);
986 TF_RETURN_IF_ERROR(
987 computation->ReplaceWithNewInstruction(while_op, nested(new_while_op)));
988 for (auto& instr : new_instrs) {
989 computation->AddInstruction(std::move(instr));
990 }
991 return true;
992 }
993
994 // Tries to merge loop induction variables of a given type.
995 //
996 // In this pass we're only concerned with elements of the loop's tuple that
997 // are effective-scalars of type `elem_ty`. Some terminology:
998 //
999 // - The trip counter is the first element of the loop's tuple that starts at
1000 // 0 and does x++ on each iteration.
1001 //
1002 // - An induction variable is an element of the loop's tuple that is not the
1003 // trip counter and does `x += <constant>` on each iteration of the loop.
1004 // Negative constants are OK.
1005 //
1006 // This pass adds a trip counter if one isn't already present, then replaces
1007 // each induction variable with
1008 //
1009 // <initial_value> + <trip_count> * <constant>.
1010 //
1011 // This reduces the number of scalar operations in the loop, which is important
1012 // e.g. on GPUs, where each scalar operation is nontrivially expensive because
1013 // it's a separate kernel launch.
1014 //
1015 // Returns the new loop if a change was made, or null if no change was made.
1016 // Note that the new loop is not a valid replacement for the old loop; it may
1017 // need to be wrapped in a tuple that changes its shape. We return the loop
1018 // itself so that you can call TryMergeInductionVariables in a loop, once for
1019 // each integral type elem_ty.
TryMergeInductionVariables(HloInstruction * while_op,PrimitiveType elem_ty)1020 static StatusOr<HloInstruction*> TryMergeInductionVariables(
1021 HloInstruction* while_op, PrimitiveType elem_ty) {
1022 CHECK(primitive_util::IsIntegralType(elem_ty)) << PrimitiveType_Name(elem_ty);
1023 HloModule* module = while_op->GetModule();
1024 HloComputation* computation = while_op->parent();
1025 auto* while_init = while_op->mutable_operand(0);
1026 auto* while_body = while_op->while_body();
1027 auto* while_cond = while_op->while_condition();
1028 auto* while_body_root = while_body->root_instruction();
1029 if (while_init->opcode() != HloOpcode::kTuple ||
1030 while_body_root->opcode() != HloOpcode::kTuple) {
1031 return nullptr;
1032 }
1033
1034 TF_RET_CHECK(while_cond->num_parameters() == 1);
1035 TF_RET_CHECK(while_body->num_parameters() == 1);
1036 TF_RET_CHECK(
1037 ShapeUtil::Compatible(while_init->shape(), while_body_root->shape()));
1038 Shape while_shape = while_init->shape();
1039
1040 // The tuple index of the trip counter, if one is present.
1041 absl::optional<int64> trip_counter;
1042 // Maps the tuple index of each induction variable to its constant increment.
1043 absl::flat_hash_map<int64, const HloConstantInstruction*> induction_vars;
1044 for (int64_t i = 0; i < while_body_root->operand_count(); ++i) {
1045 HloInstruction* constant;
1046 if (!Match(while_body_root->mutable_operand(i),
1047 m::AddAnyOrder(m::GetTupleElement(m::Parameter(), i),
1048 m::ConstantScalar(&constant))
1049 .WithShape(m::Shape().WithElementType(elem_ty)))) {
1050 continue;
1051 }
1052 if (!trip_counter && constant->literal().IsAll(1) &&
1053 while_init->operand(i)->IsConstant() &&
1054 while_init->operand(i)->literal().IsAll(0)) {
1055 VLOG(10) << "Found existing trip counter at index " << i;
1056 trip_counter = i;
1057 } else {
1058 VLOG(10) << "Found induction variable at index " << i;
1059 induction_vars.emplace(i, Cast<HloConstantInstruction>(constant));
1060 }
1061 }
1062
1063 // There's only something to simplify if we can either:
1064 //
1065 // - combine one or more induction vars with an existing trip counter, or
1066 // - replace two or more induction variables with a new trip counter.
1067 //
1068 // Put another way, there's only something to simplify if the number of
1069 // induction vars plus the number of existing trip counters (0 or 1) is >= 2.
1070 if (induction_vars.size() + (trip_counter.has_value() ? 1 : 0) < 2) {
1071 return nullptr;
1072 }
1073
1074 // OK, we're going to do the transformation! Set up some helpers.
1075
1076 // `new_instrs` holds instructions created outside of a computation for
1077 // cloning. Elements added here just need to live until the end of the
1078 // relevant CloneWithReplacement call.
1079 std::vector<std::unique_ptr<HloInstruction>> new_instrs;
1080 auto add_new_instr = [&](std::unique_ptr<HloInstruction> instr) {
1081 new_instrs.push_back(std::move(instr));
1082 return new_instrs.back().get();
1083 };
1084
1085 auto add_binary_op = [&](const Shape& shape, HloOpcode opcode,
1086 HloInstruction* lhs, HloInstruction* rhs) {
1087 // Reshape lhs/rhs to the output shape if necessary. This deals with the
1088 // fact that induction variables need only be effective scalars, not true
1089 // scalars.
1090 if (!ShapeUtil::Compatible(shape, lhs->shape())) {
1091 lhs = add_new_instr(HloInstruction::CreateReshape(shape, lhs));
1092 }
1093 if (!ShapeUtil::Compatible(shape, rhs->shape())) {
1094 rhs = add_new_instr(HloInstruction::CreateReshape(shape, rhs));
1095 }
1096 return add_new_instr(HloInstruction::CreateBinary(shape, opcode, lhs, rhs));
1097 };
1098
1099 auto add_gte = [&](HloInstruction* src, int64_t idx) {
1100 return add_new_instr(HloInstruction::CreateGetTupleElement(
1101 src->shape().tuple_shapes(idx), src, idx));
1102 };
1103
1104 // Our new while loop will have the same shape as the old while loop, except
1105 // we'll add a trip counter to the end if it wasn't originally present.
1106 Shape new_while_shape = while_shape;
1107 bool added_trip_counter = false;
1108 if (!trip_counter) {
1109 VLOG(10) << "Adding new trip counter to end of loop's tuple.";
1110 trip_counter = new_while_shape.tuple_shapes_size();
1111 *new_while_shape.add_tuple_shapes() =
1112 ShapeUtil::MakeShape(elem_ty, /*dimensions=*/{});
1113 added_trip_counter = true;
1114 }
1115
1116 // Converts `instr` into a tuple of the "old" form -- that is, to a tuple with
1117 // shape `while_body->shape()` and where the induction variables are "reified"
1118 // (i.e. they have value <init> + <counter> * <constant>).
1119 auto convert_to_old_form = [&](HloInstruction* instr) {
1120 CHECK(ShapeUtil::Compatible(instr->shape(), new_while_shape));
1121 std::vector<HloInstruction*> tuple_elems;
1122 for (int64_t i = 0; i < while_shape.tuple_shapes_size(); ++i) {
1123 const auto& elem_shape = while_shape.tuple_shapes(i);
1124 if (!induction_vars.count(i)) {
1125 tuple_elems.push_back(add_gte(instr, i));
1126 continue;
1127 }
1128 tuple_elems.push_back(add_binary_op(
1129 elem_shape, HloOpcode::kAdd, add_gte(instr, i),
1130 add_binary_op(elem_shape, HloOpcode::kMultiply,
1131 add_gte(instr, *trip_counter),
1132 add_new_instr(induction_vars.at(i)->Clone()))));
1133 }
1134 return HloInstruction::CreateTuple(tuple_elems);
1135 };
1136
1137 // Converts `root` into a tuple of the "new" form -- that is, to a tuple with
1138 // shape `new_while_shape` and where the induction variables (but not trip
1139 // counters) are replaced with their unchanging <loop_body_param> values.
1140 auto convert_to_new_form = [&](HloInstruction* old_root,
1141 HloParameterInstruction* loop_body_param) {
1142 CHECK(ShapeUtil::Compatible(old_root->shape(), while_shape));
1143 std::vector<HloInstruction*> tuple_elems;
1144
1145 // In the new form, induction variables come from `init`, everything else
1146 // (including the trip counter if it's not one we created ourselves) comes
1147 // from the `root` tuple unmodified.
1148 for (int64_t i = 0; i < while_shape.tuple_shapes_size(); ++i) {
1149 tuple_elems.push_back(
1150 add_gte((induction_vars.count(i) ? loop_body_param : old_root), i));
1151 }
1152 // If we created a trip counter ourselves, add 1 to it in the next
1153 // iteration.
1154 if (added_trip_counter) {
1155 tuple_elems.push_back(add_binary_op(
1156 new_while_shape.tuple_shapes(*trip_counter), HloOpcode::kAdd,
1157 add_gte(loop_body_param, *trip_counter),
1158 add_new_instr(
1159 HloInstruction::CreateConstant(LiteralUtil::One(elem_ty)))));
1160 }
1161
1162 return HloInstruction::CreateTuple(tuple_elems);
1163 };
1164
1165 // Creates a new init tuple, which is the same as the old init tuple except if
1166 // we added a trip counter, it's set to 0.
1167 auto get_new_while_init = [&](HloInstruction* init) {
1168 CHECK(ShapeUtil::Compatible(init->shape(), while_shape));
1169 if (!added_trip_counter) {
1170 return init;
1171 }
1172 std::vector<HloInstruction*> tuple_elems;
1173 for (int64_t i = 0; i < while_shape.tuple_shapes_size(); ++i) {
1174 tuple_elems.push_back(add_gte(init, i));
1175 }
1176 tuple_elems.push_back(add_new_instr(
1177 HloInstruction::CreateConstant(LiteralUtil::Zero(elem_ty))));
1178 return add_new_instr(HloInstruction::CreateTuple(tuple_elems));
1179 };
1180
1181 std::unique_ptr<HloComputation> new_while_cond =
1182 while_cond->CloneWithReplacementPairs({
1183 while_cond->parameter_instruction(0),
1184 convert_to_old_form(add_new_instr(HloInstruction::CreateParameter(
1185 0, new_while_shape,
1186 while_cond->parameter_instruction(0)->name()))),
1187 });
1188
1189 // Creating the new while body proceeds in two steps. First we convert the
1190 // users of the parameter to the old form. Then as a second
1191 // CloneWithReplacement operation we convert the root to the new form. We
1192 // have to do this in two steps because the new root needs to use the new
1193 // param0, and during the first clone operation, only the *old-form* param0 is
1194 // accessible.
1195 //
1196 // We have to add temp_new_while_body to the module because cloning a
1197 // computation touches the module (to get its NameUniquer).
1198 HloComputation* temp_new_while_body =
1199 module->AddEmbeddedComputation(while_body->CloneWithReplacementPairs({
1200 while_body->parameter_instruction(0),
1201 convert_to_old_form(add_new_instr(HloInstruction::CreateParameter(
1202 0, new_while_shape,
1203 while_body->parameter_instruction(0)->name()))),
1204 }));
1205 std::unique_ptr<HloComputation> new_while_body =
1206 temp_new_while_body->CloneWithReplacementPairs({
1207 temp_new_while_body->root_instruction(),
1208 convert_to_new_form(
1209 add_new_instr(temp_new_while_body->root_instruction()->Clone()),
1210 Cast<HloParameterInstruction>(
1211 temp_new_while_body->parameter_instruction(0))),
1212 });
1213 TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(temp_new_while_body));
1214
1215 // Create the final while loop, and add any new instructions created to
1216 // `computation`.
1217 new_instrs.clear();
1218 auto* new_while = computation->AddInstruction(HloInstruction::CreateWhile(
1219 new_while_shape,
1220 module->AddEmbeddedComputation(std::move(new_while_cond)),
1221 module->AddEmbeddedComputation(std::move(new_while_body)),
1222 get_new_while_init(while_init)));
1223 CopyBackendConfig(while_op, new_while);
1224 CopyFrontendAttributes(while_op, new_while);
1225 TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
1226 while_op, convert_to_old_form(new_while)));
1227 for (auto& instr : new_instrs) {
1228 computation->AddInstruction(std::move(instr));
1229 }
1230 return new_while;
1231 }
1232
Run(HloModule * module)1233 StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
1234 XLA_VLOG_LINES(3,
1235 "WhileLoopSimplifier::Run(), before:\n" + module->ToString());
1236 bool changed = false;
1237
1238 // Gather all the while ops in our module. We do this ahead of time so we
1239 // don't have to worry about mutating the lists of computations or
1240 // instructions while we iterate.
1241 std::vector<HloInstruction*> while_ops;
1242 for (auto* comp : module->computations()) {
1243 for (auto* instr : comp->instructions()) {
1244 if (instr->opcode() == HloOpcode::kWhile) {
1245 while_ops.push_back(instr);
1246 }
1247 }
1248 }
1249
1250 for (HloInstruction* while_op : while_ops) {
1251 // We can't remove while loops that contain send/recv nodes, because we rely
1252 // on the particular loop structure around the node matching on the send and
1253 // recv sides. Other while simplifications require us to remove the loop
1254 // and replace it with a new one, so we can't do that either.
1255 if (ContainsInstrWithOpcode(while_op->while_body(),
1256 {HloOpcode::kSend, HloOpcode::kSendDone,
1257 HloOpcode::kRecv, HloOpcode::kRecvDone}) ||
1258 ContainsInstrWithOpcode(while_op->while_condition(),
1259 {HloOpcode::kSend, HloOpcode::kSendDone,
1260 HloOpcode::kRecv, HloOpcode::kRecvDone})) {
1261 VLOG(2) << "Not attempting to simplify while loop because it contains a "
1262 "send/recv node: "
1263 << while_op->ToShortString();
1264 continue;
1265 }
1266
1267 TF_ASSIGN_OR_RETURN(bool result, TryPropagateConstant(while_op));
1268 changed |= result;
1269
1270 TF_ASSIGN_OR_RETURN(result, TryRemoveWhileLoop(while_op));
1271 changed |= result;
1272
1273 if (result) {
1274 // Don't continue simplifying after successfully removing the while loop
1275 // -- that would result in use-after-free nastiness.
1276 continue;
1277 }
1278
1279 // TODO(b/119281462): Cowardly refuse to perform any of the following
1280 // optimizations in the presence of kDomain instructions. It seems that
1281 // modifying a while loop's tuple doesn't work when kDomain is present.
1282 if (ContainsInstrWithOpcode(while_op->while_body(), {HloOpcode::kDomain}) ||
1283 ContainsInstrWithOpcode(while_op->while_condition(),
1284 {HloOpcode::kDomain})) {
1285 continue;
1286 }
1287
1288 // Each of the optimizations below modifies the while loop itself if it's
1289 // successful, meaning that `while_op` is no longer valid after one of these
1290 // transformations returns true.
1291
1292 TF_ASSIGN_OR_RETURN(result, TryRemoveRepeatedWhileTupleIndices(while_op));
1293 changed |= result;
1294 if (result) {
1295 continue;
1296 }
1297
1298 TF_ASSIGN_OR_RETURN(result, TryFlattenNestedTuples(while_op));
1299 changed |= result;
1300 if (result) {
1301 continue;
1302 }
1303
1304 TF_ASSIGN_OR_RETURN(result, TryRemoveDeadWhileParams(while_op));
1305
1306 changed |= result;
1307 if (result) {
1308 continue;
1309 }
1310
1311 TF_ASSIGN_OR_RETURN(result, TryRemoveConstantParams(while_op));
1312 changed |= result;
1313 if (result) {
1314 continue;
1315 }
1316
1317 bool merged_induction_vars = false;
1318 // Notably missing from this list are S16 and U16. These don't currently
1319 // work because S/U16 literals are not implemented.
1320 for (auto elem_ty : {S8, U8, S32, U32, S64, U64}) {
1321 TF_ASSIGN_OR_RETURN(auto* new_while_op,
1322 TryMergeInductionVariables(while_op, elem_ty));
1323 if (new_while_op) {
1324 while_op = new_while_op;
1325 changed = true;
1326 merged_induction_vars = true;
1327 }
1328 }
1329 if (merged_induction_vars) {
1330 continue;
1331 }
1332 }
1333
1334 XLA_VLOG_LINES(3,
1335 "WhileLoopSimplifier::Run(), after:\n" + module->ToString());
1336 return changed;
1337 }
1338
1339 } // namespace xla
1340