• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
17 
18 #include "absl/algorithm/container.h"
19 #include "absl/container/flat_hash_map.h"
20 #include "absl/container/flat_hash_set.h"
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/str_join.h"
23 #include "absl/types/optional.h"
24 #include "tensorflow/compiler/xla/primitive_util.h"
25 #include "tensorflow/compiler/xla/service/call_inliner.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
30 #include "tensorflow/compiler/xla/service/hlo_query.h"
31 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
32 #include "tensorflow/compiler/xla/service/while_loop_analysis.h"
33 
34 namespace xla {
35 
36 namespace m = match;
37 using absl::optional;
38 using hlo_query::ContainsInstrWithOpcode;
39 
40 // A helper function that copy the raw JSON-encoded backend config from the old
41 // while op to the new one.
CopyBackendConfig(HloInstruction * old_while_op,HloInstruction * new_while_op)42 void CopyBackendConfig(HloInstruction* old_while_op,
43                        HloInstruction* new_while_op) {
44   new_while_op->set_raw_backend_config_string(
45       old_while_op->raw_backend_config_string());
46 }
47 
48 // A helper function that copy the frontend attributes from the old while op to
49 // the new one.
CopyFrontendAttributes(HloInstruction * old_while_op,HloInstruction * new_while_op)50 void CopyFrontendAttributes(HloInstruction* old_while_op,
51                             HloInstruction* new_while_op) {
52   new_while_op->add_frontend_attributes(old_while_op->frontend_attributes());
53 }
54 
55 // This is a utility function that removes the given tuple indices from the
56 // while loop init, body, and condition. The final shape returned is still the
57 // same as before.
RemoveDeadTupleIndices(HloInstruction * while_op,absl::flat_hash_set<int64> & used_tuple_indices)58 static StatusOr<HloInstruction*> RemoveDeadTupleIndices(
59     HloInstruction* while_op, absl::flat_hash_set<int64>& used_tuple_indices) {
60   // Build up maps from the old/new to the new/old tuple indices.
61   std::vector<int64> new_to_old_tuple_idx(used_tuple_indices.begin(),
62                                           used_tuple_indices.end());
63   absl::c_sort(new_to_old_tuple_idx);
64 
65   HloModule* module = while_op->GetModule();
66   HloComputation* computation = while_op->parent();
67   HloInstruction* while_init = while_op->mutable_operand(0);
68   HloComputation* while_cond = while_op->while_condition();
69   HloComputation* while_body = while_op->while_body();
70   HloInstruction* while_body_root = while_body->root_instruction();
71 
72   auto print_no_metadata = HloPrintOptions().set_print_metadata(false);
73 
74   absl::flat_hash_map<int64, int64> old_to_new_tuple_idx;
75   for (int64_t new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) {
76     int64_t old_idx = new_to_old_tuple_idx[new_idx];
77     old_to_new_tuple_idx[old_idx] = new_idx;
78     VLOG(2) << "Remapping tuple index " << old_idx << " to " << new_idx;
79   }
80 
81   // Compute the shape of the while op after we remove the dead indices.
82   std::vector<Shape> new_while_tuple_elem_shapes;
83   new_while_tuple_elem_shapes.reserve(new_to_old_tuple_idx.size());
84   for (int64_t old_idx : new_to_old_tuple_idx) {
85     new_while_tuple_elem_shapes.push_back(
86         while_init->shape().tuple_shapes(old_idx));
87   }
88   Shape new_while_shape =
89       ShapeUtil::MakeTupleShape(new_while_tuple_elem_shapes);
90 
91   // Returns a map from elements in the computation to new instructions which
92   // replace the old instructions after we remove unused elements from the while
93   // tuple.
94   auto make_while_computation_replacements = [&](const HloComputation* comp) {
95     absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
96         replacements;
97 
98     auto* param = comp->parameter_instruction(0);
99     replacements.emplace(param, HloInstruction::CreateParameter(
100                                     0, new_while_shape, param->name()));
101 
102     // Materialize param's users, since we're about to add new ones below.
103     std::vector<HloInstruction*> materialized_users(param->users().begin(),
104                                                     param->users().end());
105     for (const auto* user : materialized_users) {
106       // The while body root is handled separately.
107       if (user == while_body_root) {
108         continue;
109       }
110       CHECK_EQ(user->opcode(), HloOpcode::kGetTupleElement)
111           << user->ToString(print_no_metadata);
112 
113       int64_t old_idx = user->tuple_index();
114       auto new_idx_iter = old_to_new_tuple_idx.find(old_idx);
115       if (new_idx_iter != old_to_new_tuple_idx.end()) {
116         // This is a GTE of an index that survives.  Replace it.
117         replacements.emplace(
118             user, HloInstruction::CreateGetTupleElement(user->shape(), param,
119                                                         new_idx_iter->second));
120       } else {
121         // This is a GTE of an index that we've removed.  Remove it from the
122         // cloned computation.
123         CHECK(user->user_count() == 0 ||
124               user->user_count() == 1 &&
125                   user->users().front() == while_body_root)
126             << "Instruction " << user->ToString(print_no_metadata)
127             << " should be unused (except by root of while body), but has "
128                "users: {"
129             << absl::StrJoin(user->users(), ", ",
130                              [&](string* out, const HloInstruction* instr) {
131                                absl::StrAppend(
132                                    out, instr->ToString(print_no_metadata));
133                              })
134             << "}";
135 
136         replacements.emplace(user, nullptr);
137       }
138     }
139     return replacements;
140   };
141 
142   // Create the new while condition, body, and init value.
143   std::unique_ptr<HloComputation> new_while_cond =
144       while_cond->CloneWithReplacements(
145           make_while_computation_replacements(while_cond));
146 
147   absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
148       while_body_replacements = make_while_computation_replacements(while_body);
149   std::vector<HloInstruction*> new_while_body_root_elems;
150   new_while_body_root_elems.reserve(new_to_old_tuple_idx.size());
151   for (int64_t old_idx : new_to_old_tuple_idx) {
152     new_while_body_root_elems.push_back(
153         while_body_root->mutable_operand(old_idx));
154   }
155   while_body_replacements.emplace(
156       while_body_root, HloInstruction::CreateTuple(new_while_body_root_elems));
157   std::unique_ptr<HloComputation> new_while_body =
158       while_body->CloneWithReplacements(std::move(while_body_replacements));
159 
160   // Add a new while_init instruction that repackages the old while_init
161   // instruction's elements.  We rely on the AlgebraicSimplifier and DCE to
162   // clean this up in the common case where while_init is a tuple op.  (It's
163   // definitely tuple-shaped, but it's not necessarily a tuple op.)
164   std::vector<HloInstruction*> new_while_init_elems;
165   new_while_init_elems.reserve(new_to_old_tuple_idx.size());
166   for (int64_t old_idx : new_to_old_tuple_idx) {
167     new_while_init_elems.push_back(
168         computation->AddInstruction(HloInstruction::CreateGetTupleElement(
169             while_init->shape().tuple_shapes(old_idx), while_init, old_idx)));
170   }
171   auto* new_while_init = computation->AddInstruction(
172       HloInstruction::CreateTuple(new_while_init_elems));
173 
174   // Create the new while op.
175   auto* new_while_op = computation->AddInstruction(HloInstruction::CreateWhile(
176       new_while_shape,
177       module->AddEmbeddedComputation(std::move(new_while_cond)),
178       module->AddEmbeddedComputation(std::move(new_while_body)),
179       new_while_init));
180   CopyBackendConfig(while_op, new_while_op);
181   CopyFrontendAttributes(while_op, new_while_op);
182 
183   // Create a tuple op that recreates the output of the old while op.  That is,
184   // we transform to
185   //
186   //  new_while_init   while_init
187   //       |              |
188   //       V              |
189   //   new_while          |
190   //       |              |
191   //       -------|   |----
192   //              V   V
193   //            new_tuple
194   //                |
195   //                V
196   //    (orig. users of while op)
197   //
198   // The tuple simplifier will then simplify this if possible, removing
199   // new_tuple and while_init.
200   std::vector<HloInstruction*> new_tuple_elems;
201   const int64_t tuple_size = ShapeUtil::TupleElementCount(while_init->shape());
202   for (int64_t old_idx = 0; old_idx < tuple_size; ++old_idx) {
203     auto new_tuple_idx_it = old_to_new_tuple_idx.find(old_idx);
204     if (new_tuple_idx_it != old_to_new_tuple_idx.end()) {
205       int64_t gte_idx = new_tuple_idx_it->second;
206       new_tuple_elems.push_back(
207           computation->AddInstruction(HloInstruction::CreateGetTupleElement(
208               new_while_op->shape().tuple_shapes(gte_idx), new_while_op,
209               gte_idx)));
210     } else {
211       new_tuple_elems.push_back(
212           computation->AddInstruction(HloInstruction::CreateGetTupleElement(
213               while_init->shape().tuple_shapes(old_idx), while_init, old_idx)));
214     }
215   }
216   HloInstruction* new_tuple =
217       computation->AddInstruction(HloInstruction::CreateTuple(new_tuple_elems));
218   TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, new_tuple));
219 
220   return new_while_op;
221 }
222 
223 // Tries to remove elements in a while loop's tuple that aren't used within the
224 // loop.
225 //
226 // Specifically, if a loop is tuple-shaped, and there exists some element of
227 // that tuple that is not used by the loop condition and is not used by the loop
228 // body except to pass it to the next iteration of the loop, then we can remove
229 // that element from the loop's tuples.
TryRemoveDeadWhileParams(HloInstruction * while_op)230 static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
231   CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
232 
233   // Don't try this transformation if the while loop isn't removable, since if
234   // it succeeds ultimately we're going to have to replace the old while loop
235   // with a new one.
236   if (!while_op->parent()->IsSafelyRemovable(while_op)) {
237     VLOG(2) << "Can't remove dead parameters from non-removable while op.";
238     return false;
239   }
240 
241   HloInstruction* while_init = while_op->mutable_operand(0);
242   HloComputation* while_cond = while_op->while_condition();
243   HloComputation* while_body = while_op->while_body();
244   HloInstruction* while_body_root = while_body->root_instruction();
245 
246   if (!while_init->shape().IsTuple()) {
247     VLOG(2) << "While op's carried value isn't tuple shaped.";
248     return false;
249   }
250 
251   if (while_body_root->opcode() != HloOpcode::kTuple) {
252     VLOG(2) << "While body's root is not a tuple(...) instruction.";
253     return false;
254   }
255 
256   auto print_no_metadata = HloPrintOptions().set_print_metadata(false);
257 
258   // Bail if param0 of while_cond or while_body has users which aren't of type
259   // get-tuple-element.
260   for (const HloInstruction* instr : {while_body->parameter_instruction(0),
261                                       while_cond->parameter_instruction(0)}) {
262     for (const HloInstruction* user : instr->users()) {
263       if (user->opcode() != HloOpcode::kGetTupleElement) {
264         VLOG(2) << "Cowardly refusing to analyze while loop with "
265                 << instr->ToString(print_no_metadata)
266                 << " used by non-GTE instruction "
267                 << user->ToString(print_no_metadata) << " in computation "
268                 << instr->parent()->name();
269         return false;
270       }
271     }
272   }
273 
274   const int64_t tuple_size = ShapeUtil::TupleElementCount(while_init->shape());
275   if (tuple_size == 0) {
276     VLOG(2) << "Can't remove elements from while loop's tuple -- it's already "
277                "empty.";
278     return false;
279   }
280 
281   absl::flat_hash_set<int64> used_tuple_indices;
282   for (HloComputation* comp : {while_body, while_cond}) {
283     // The HLO verifier ensures that while_input's shape matches while_init's
284     // shape, which we verified above is a tuple.
285     HloInstruction* while_input = comp->parameter_instruction(0);
286 
287     for (const HloInstruction* user : while_input->users()) {
288       // This user doesn't count if it's only used by the while body's root, and
289       // the root places the tuple element into the same index of the tuple as
290       // it came from.  That just amounts to us carrying the variable through
291       // the loop.
292       //
293       // Careful: HloInstruction::operand_index returns the first index the
294       // operand appears in, but it may appear more than once!
295       if (user->user_count() == 1 && user->users().front() == while_body_root &&
296           while_body_root->operand_index(user) == user->tuple_index() &&
297           absl::c_count(while_body_root->operands(), user) == 1) {
298         continue;
299       }
300 
301       used_tuple_indices.insert(user->tuple_index());
302       if (used_tuple_indices.size() == tuple_size) {
303         VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
304                 << " uses all of its inputs; no simplification possible.";
305         return false;
306       }
307     }
308   }
309 
310   absl::flat_hash_set<int64> used_indices_after_loop;
311   if (while_op == while_op->parent()->root_instruction()) {
312     for (int64_t i = 0; i < while_body_root->operand_count(); ++i) {
313       used_indices_after_loop.insert(i);
314     }
315   }
316   for (auto user : while_op->users()) {
317     if (user->opcode() != HloOpcode::kGetTupleElement) {
318       for (int64_t i = 0; i < while_body_root->operand_count(); ++i) {
319         used_indices_after_loop.insert(i);
320       }
321       break;
322     }
323     used_indices_after_loop.insert(user->tuple_index());
324   }
325   // If a tuple element is used after the loop but not passed unmodified from
326   // the while body's param0 through to the while body's root, count that
327   // element as "used", since removing that element would be observable.
328   for (int64_t i = 0; i < while_body_root->operand_count(); ++i) {
329     if (used_tuple_indices.contains(i) ||
330         !used_indices_after_loop.contains(i)) {
331       continue;
332     }
333 
334     auto* operand = while_body_root->operand(i);
335     if (operand->opcode() != HloOpcode::kGetTupleElement ||
336         operand->operand(0) != while_body->parameter_instruction(0) ||
337         operand->tuple_index() != i) {
338       VLOG(2) << "Tuple index " << i
339               << " is not passed through loop body unmodified.";
340       used_tuple_indices.insert(i);
341 
342       if (used_tuple_indices.size() == tuple_size) {
343         VLOG(2) << "Loop " << while_op->ToString(print_no_metadata)
344                 << " uses all of its inputs; no simplification possible.";
345         return false;
346       }
347     }
348   }
349 
350   // If we got here, used_tuple_indices.size() < tuple_size, meaning some
351   // elements of the loop's tuple aren't used by while_body or while_cond.
352   CHECK_LT(used_tuple_indices.size(), tuple_size);
353 
354   VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size()
355           << " elements from tuple of "
356           << while_op->ToString(print_no_metadata);
357 
358   TF_ASSIGN_OR_RETURN(while_op,
359                       RemoveDeadTupleIndices(while_op, used_tuple_indices));
360 
361   return true;
362 }
363 
364 // This is a helper function for TryRemoveRepeatedWhileTupleIndices. It removes
365 // duplicates by replacing them with tuple_index, followed by a call to
366 // RemoveDeadTupleIndices.
TryRemoveRepeatedWhileTupleIndicesHelper(HloInstruction * while_op,const int64_t tuple_index,absl::flat_hash_set<int64> & duplicates)367 static StatusOr<HloInstruction*> TryRemoveRepeatedWhileTupleIndicesHelper(
368     HloInstruction* while_op, const int64_t tuple_index,
369     absl::flat_hash_set<int64>& duplicates) {
370   HloComputation* while_cond = while_op->while_condition();
371   HloComputation* while_body = while_op->while_body();
372   HloInstruction* while_init = while_op->mutable_operand(0);
373 
374   VLOG(2) << "while_init " << while_init->ToString() << " operands "
375           << while_init->operand_count();
376   VLOG(2) << "while_body_root " << while_body->root_instruction()->ToString()
377           << " operands " << while_body->root_instruction()->operand_count();
378 
379   // Change the loop body and condition such that uses of the duplicates are
380   // replaced with the original tuple element.
381   for (HloComputation* comp : {while_body, while_cond}) {
382     auto new_get = comp->AddInstruction(HloInstruction::CreateGetTupleElement(
383         comp->parameter_instruction(0)->shape().tuple_shapes(tuple_index),
384         comp->parameter_instruction(0), tuple_index));
385 
386     std::vector<HloInstruction*> instrs_to_replace;
387     for (auto* instr : comp->instructions()) {
388       if (instr->opcode() == HloOpcode::kGetTupleElement &&
389           duplicates.contains(instr->tuple_index()) &&
390           instr->operand(0) == comp->parameter_instruction(0)) {
391         instrs_to_replace.push_back(instr);
392       }
393     }
394 
395     for (auto instr : instrs_to_replace) {
396       TF_RETURN_IF_ERROR(comp->ReplaceInstruction(instr, new_get));
397     }
398   }
399 
400   // We know which tuple indices are useful; i.e, those which aren't duplicates.
401   absl::flat_hash_set<int64> used_tuple_indices;
402   for (int index = 0; index < while_init->shape().tuple_shapes_size();
403        ++index) {
404     if (!duplicates.count(index)) {
405       used_tuple_indices.insert(index);
406     }
407   }
408 
409   // Remove the duplicate tuple elements.
410   TF_ASSIGN_OR_RETURN(while_op,
411                       RemoveDeadTupleIndices(while_op, used_tuple_indices));
412 
413   return while_op;
414 }
415 
416 // If the while loop init passes the same values to several tuple indices, and
417 // if the body keeps on passing them through, we can remove the duplicates.
TryRemoveRepeatedWhileTupleIndices(HloInstruction * while_op)418 static StatusOr<bool> TryRemoveRepeatedWhileTupleIndices(
419     HloInstruction* while_op) {
420   CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
421 
422   int index_to_investigate = 0;
423   // Don't try this transformation if the while loop isn't removable, since if
424   // it succeeds ultimately we're going to have to replace the old while loop
425   // with a new one.
426   if (!while_op->parent()->IsSafelyRemovable(while_op)) {
427     VLOG(2) << "Can't remove dead parameters from non-removable while op.";
428     return false;
429   }
430 
431   HloInstruction* while_init = while_op->mutable_operand(0);
432   HloComputation* while_cond = while_op->while_condition();
433   HloComputation* while_body = while_op->while_body();
434   HloInstruction* while_body_root = while_body->root_instruction();
435 
436   if (!while_init->shape().IsTuple()) {
437     VLOG(2) << "While op's carried value isn't tuple shaped.";
438     return false;
439   }
440 
441   bool changed = false;
442   while (index_to_investigate < while_init->shape().tuple_shapes_size()) {
443     if (!while_init->shape().IsTuple() ||
444         while_init->opcode() != HloOpcode::kTuple) {
445       VLOG(2) << "While op's carried value isn't tuple shaped.";
446       return false;
447     }
448 
449     if (while_body_root->opcode() != HloOpcode::kTuple) {
450       VLOG(2) << "While body's root is not a tuple(...) instruction.";
451       return false;
452     }
453 
454     auto& while_shape = while_init->shape();
455     VLOG(2) << "Iterating " << index_to_investigate;
456 
457     absl::flat_hash_set<int64> duplicates;
458     auto* pivot_init_elem = while_init->operand(index_to_investigate);
459     auto* pivot_body_elem = while_body_root->operand(index_to_investigate);
460     if (pivot_body_elem->opcode() == HloOpcode::kGetTupleElement &&
461         pivot_body_elem->operand(0) == while_body->parameter_instruction(0)) {
462       if (pivot_body_elem->tuple_index() != index_to_investigate) {
463         VLOG(2) << "Mismatch between pivot_body_elem->tuple_index() "
464                 << pivot_body_elem->tuple_index() << " index_to_investigate "
465                 << index_to_investigate;
466         index_to_investigate++;
467         continue;
468       }
469     } else {
470       index_to_investigate++;
471       continue;
472     }
473 
474     // Look from index_to_investigate onwards to see if it is repeated.
475     for (int64_t i = index_to_investigate + 1;
476          i < while_shape.tuple_shapes_size(); ++i) {
477       auto* init_elem = while_init->operand(i);
478       auto* body_elem = while_body_root->operand(i);
479       if (body_elem->opcode() == HloOpcode::kGetTupleElement &&
480           body_elem->operand(0) == while_body->parameter_instruction(0)) {
481         if (body_elem->tuple_index() != i) {
482           VLOG(2) << "Mismatch between body_elem->tuple_index() "
483                   << body_elem->tuple_index() << " i " << i;
484           continue;
485         }
486       } else {
487         continue;
488       }
489 
490       if (pivot_init_elem == init_elem) {
491         VLOG(2) << "init_elem " << init_elem->ToString() << " pivot_init_elem "
492                 << pivot_init_elem->ToString();
493         VLOG(2) << "body_elem " << body_elem->ToString() << " pivot_body_elem "
494                 << pivot_body_elem->ToString();
495         duplicates.insert(i);
496       }
497     }
498 
499     // If duplicates are found, call the helper to remove them.
500     if (!duplicates.empty()) {
501       VLOG(2) << "Duplicate found " << duplicates.size() << " pivot_init "
502               << pivot_init_elem->ToString();
503       TF_ASSIGN_OR_RETURN(while_op,
504                           TryRemoveRepeatedWhileTupleIndicesHelper(
505                               while_op, index_to_investigate, duplicates));
506       changed = true;
507       VLOG(2) << "Changed while_op " << while_op->ToString()
508               << " while_op operand count " << while_op->operand_count();
509       // Update the while loop variables so we can continue looking for
510       // duplicates of a different index.
511       while_init = while_op->mutable_operand(0);
512       while_cond = while_op->while_condition();
513       while_body = while_op->while_body();
514       while_body_root = while_body->root_instruction();
515     }
516     index_to_investigate++;
517   }
518 
519   return changed;
520 }
521 
522 // Removes each loop parameter (i.e. member of the while loop tuple) that is a
523 // constant and is the same in the while loop body and the while loop init.
TryRemoveConstantParams(HloInstruction * while_op)524 static StatusOr<bool> TryRemoveConstantParams(HloInstruction* while_op) {
525   HloModule* module = while_op->GetModule();
526   HloComputation* computation = while_op->parent();
527   auto* while_init = while_op->mutable_operand(0);
528   auto* while_body = while_op->while_body();
529   auto* while_cond = while_op->while_condition();
530   auto* while_body_root = while_body->root_instruction();
531   if (while_init->opcode() != HloOpcode::kTuple ||
532       while_body_root->opcode() != HloOpcode::kTuple) {
533     return false;
534   }
535 
536   TF_RET_CHECK(while_cond->num_parameters() == 1);
537   TF_RET_CHECK(while_body->num_parameters() == 1);
538   TF_RET_CHECK(
539       ShapeUtil::Compatible(while_init->shape(), while_body_root->shape()));
540 
541   absl::flat_hash_set<int64> constant_tuple_indices;
542   const auto& while_shape = while_init->shape();
543   for (int64_t i = 0; i < while_shape.tuple_shapes_size(); ++i) {
544     auto* init_elem = while_init->operand(i);
545     auto* body_elem = while_body_root->operand(i);
546     if (init_elem->opcode() == HloOpcode::kConstant &&
547         body_elem->opcode() == HloOpcode::kConstant &&
548         init_elem->literal() == body_elem->literal()) {
549       constant_tuple_indices.insert(i);
550     }
551   }
552 
553   if (constant_tuple_indices.empty()) {
554     return false;
555   }
556 
557   // OK, we found some constant elements of the while parameter!  Eliminate
558   // them.
559   std::vector<Shape> new_while_shape_elems;
560   for (int64_t i = 0; i < while_shape.tuple_shapes_size(); ++i) {
561     if (!constant_tuple_indices.count(i)) {
562       new_while_shape_elems.push_back(while_shape.tuple_shapes(i));
563     }
564   }
565   Shape new_while_shape = ShapeUtil::MakeTupleShape(new_while_shape_elems);
566 
567   // `new_instrs` holds instructions created outside of a computation for
568   // cloning.  Elements added here just need to live until the end of the
569   // relevant CloneWithReplacement call.
570   std::vector<std::unique_ptr<HloInstruction>> new_instrs;
571   auto add_new_instr = [&](std::unique_ptr<HloInstruction> instr) {
572     new_instrs.push_back(std::move(instr));
573     return new_instrs.back().get();
574   };
575 
576   // Returns a new tuple without the elements of constant_tuple_indices.
577   auto remove_constant_elems = [&](HloInstruction* instr) {
578     CHECK(ShapeUtil::Compatible(instr->shape(), while_shape));
579 
580     std::vector<HloInstruction*> tuple_elems;
581     for (int64_t i = 0; i < while_shape.tuple_shapes_size(); ++i) {
582       if (!constant_tuple_indices.count(i)) {
583         tuple_elems.push_back(
584             add_new_instr(HloInstruction::CreateGetTupleElement(
585                 while_shape.tuple_shapes(i), instr, i)));
586       }
587     }
588     return HloInstruction::CreateTuple(tuple_elems);
589   };
590 
591   auto add_constant_elems = [&](HloInstruction* instr) {
592     CHECK(ShapeUtil::Compatible(instr->shape(), new_while_shape));
593 
594     std::vector<HloInstruction*> tuple_elems;
595     int64_t j = 0;
596     for (int64_t i = 0; i < while_shape.tuple_shapes_size(); ++i) {
597       if (constant_tuple_indices.count(i)) {
598         tuple_elems.push_back(while_init->mutable_operand(i));
599       } else {
600         tuple_elems.push_back(
601             add_new_instr(HloInstruction::CreateGetTupleElement(
602                 while_shape.tuple_shapes(i), instr, j)));
603         ++j;
604       }
605     }
606     return HloInstruction::CreateTuple(tuple_elems);
607   };
608 
609   // Special case: constant_tuple_indices covers the whole while parameter, so
610   // the new while shape is the empty tuple.  In this case, the value of the
611   // while loop is simply equal to the value of `init`.
612   //
613   // It's unfortunate to special-case this, but it's simpler than the
614   // alternative.  The problem is that if our while parameter has no
615   // non-constant elems, the tuple returned by `add_constant_elems` won't depend
616   // on instr (the loop body/cond parameter), and therefore
617   // CloneWithReplacementPairs will *leave the parameter out entirely*, creating
618   // invalid HLO.
619   if (ShapeUtil::IsEmptyTuple(new_while_shape)) {
620     TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, while_init));
621     return true;
622   }
623 
624   std::unique_ptr<HloComputation> new_while_cond =
625       while_cond->CloneWithReplacementPairs({
626           while_cond->parameter_instruction(0),
627           add_constant_elems(add_new_instr(HloInstruction::CreateParameter(
628               0, new_while_shape,
629               while_cond->parameter_instruction(0)->name()))),
630       });
631 
632   std::unique_ptr<HloComputation> new_while_body =
633       while_body->CloneWithReplacementPairs(
634           {
635               while_body->parameter_instruction(0),
636               add_constant_elems(add_new_instr(HloInstruction::CreateParameter(
637                   0, new_while_shape,
638                   while_cond->parameter_instruction(0)->name()))),
639           },
640           {
641               while_body->root_instruction(),
642               remove_constant_elems(
643                   add_new_instr(while_body->root_instruction()->Clone())),
644           });
645 
646   // Create the final while loop, and add any new instructions created to
647   // `computation`.
648   new_instrs.clear();
649   auto* new_while_op = computation->AddInstruction(HloInstruction::CreateWhile(
650       new_while_shape,
651       module->AddEmbeddedComputation(std::move(new_while_cond)),
652       module->AddEmbeddedComputation(std::move(new_while_body)),
653       add_new_instr(remove_constant_elems(while_init))));
654   CopyBackendConfig(while_op, new_while_op);
655   CopyFrontendAttributes(while_op, new_while_op);
656   TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
657       while_op, add_constant_elems(new_while_op)));
658   for (auto& instr : new_instrs) {
659     computation->AddInstruction(std::move(instr));
660   }
661   return true;
662 }
663 
664 // Tries to remove a while loop from the graph.
665 //
666 //  - Loops with trip count of 0 can be replaced by the loop's "init" value.
667 //  - Loops with trip count of 1 can be replaced by the loop's body, with the
668 //    loop itself removed.
669 //
670 // Returns true if it made a change to the graph.
TryRemoveWhileLoop(HloInstruction * while_op)671 static StatusOr<bool> TryRemoveWhileLoop(HloInstruction* while_op) {
672   // Cowardly refuse to remove loops that are not removable.  In practice, this
673   // means that we can't remove loops that have control predecessors/successors.
674   if (!while_op->parent()->IsSafelyRemovable(while_op)) {
675     VLOG(2) << "Not attempting to remove while loop that is not removable: "
676             << while_op->ToShortString();
677     return false;
678   }
679 
680   // Refuse to remove while loops with a condition that contain side-effects,
681   // because removing a while loop is tantamount to removing its condition.
682   //
683   // TODO(jlebar): This is conservative: We could instead just run the while
684   // condition once (trip-count == 0) or twice (trip-count == 1).
685   if (while_op->while_condition()->HasSideEffect()) {
686     VLOG(2) << "Not attempting to remove while loop whose condition contains "
687                "side-effecting instructions: "
688             << while_op->ToShortString();
689     return false;
690   }
691 
692   // Remove while loops with static trip count of 0.
693   optional<int64> trip_count =
694       ComputeWhileLoopTripCount(while_op, /*max_brute_force_iters=*/1);
695   if (trip_count && *trip_count == 0) {
696     // The loop never executes, so the value of the loop is the value of its
697     // "init" operand.
698     auto computation = while_op->parent();
699 
700     // Remove while_op (i.e., call ReplaceInstruction rather than
701     // ReplaceUsesWithInstruction) so that if the algebraic simplifier is run in
702     // a loop without an intervening DCE, we don't try to re-remove the loop.
703     TF_RETURN_IF_ERROR(computation->ReplaceInstruction(
704         while_op, while_op->mutable_operand(0)));
705     return true;
706   }
707 
708   // Transform while loops with static trip count of 1 into a call op, then
709   // inline the call.
710   const auto& attrs = while_op->frontend_attributes().map();
711   bool skip_trip_count_one_simplification =
712       attrs.contains("skip-simplify-while-loops/trip-count-one") &&
713       (attrs.at("skip-simplify-while-loops/trip-count-one") == "true");
714   if (trip_count && *trip_count == 1 && !skip_trip_count_one_simplification) {
715     // Do not simplify the loop away when there is a side-effectful op,
716     // otherwise the infeed op may not inherit the data dependency from
717     // the while loop.
718     //
719     // Example: while_body (param_a) {
720     //   param_a = parameter(0)
721     //   infeed2 = infeed()
722     // }
723     //
724     // infeed1 = ...
725     // while = while(infeed1), body=while_body // infeed2 has implicit
726     // dependency on infeed1.
727     //
728     // After simplification:
729     //
730     // infeed1 = ...
731     // infeed2 = infeed() // no dependency between infeed1 and infeed2. infeed1
732     //                    // can be scheduled after infeed2.
733     //
734     bool has_side_effects = absl::c_any_of(
735         while_op->called_computations(), [](const HloComputation* computation) {
736           return computation->HasSideEffect();
737         });
738     if (!has_side_effects) {
739       auto computation = while_op->parent();
740       auto call_op = computation->AddInstruction(HloInstruction::CreateCall(
741           while_op->shape(), while_op->operands(), while_op->while_body()));
742       TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, call_op));
743       TF_ASSIGN_OR_RETURN(auto inlined_instructions_map,
744                           CallInliner::Inline(call_op));
745       (void)inlined_instructions_map;
746       return true;
747     } else {
748       VLOG(2) << "Not attempting to simplify while loop because it contains a "
749                  "side-effecting node: "
750               << while_op->ToShortString();
751     }
752   }
753   return false;
754 }
755 
TryPropagateConstant(HloInstruction * while_op)756 static StatusOr<bool> TryPropagateConstant(HloInstruction* while_op) {
757   auto while_init = while_op->operand(0);
758   if (while_init->opcode() != HloOpcode::kTuple) {
759     return false;
760   }
761 
762   auto while_body = while_op->while_body();
763   auto while_body_root = while_body->root_instruction();
764   if (while_body_root->opcode() != HloOpcode::kTuple) {
765     return false;
766   }
767 
768   auto while_body_param = while_body->parameter_instruction(0);
769   const HloInstruction::InstructionVector& root_operands =
770       while_body_root->operands();
771 
772   // Find the loop invariant tuple elements with scalar constant init value and
773   // build a map from the tuple element index to the constant value. Limit this
774   // to scalar constant values because propagating array constants can regress
775   // performance by forcing us to copy constants.
776   absl::flat_hash_map<int, const HloInstruction*> index_to_constant;
777   for (int i = 0; i < root_operands.size(); i++) {
778     const HloInstruction* init_tuple_elem = nullptr;
779     if (Match(root_operands[i],
780               m::GetTupleElement(m::Op().Is(while_body_param), i)
781                   .WithShape(m::Shape().IsScalar())) &&
782         Match(while_init->operand(i), m::Constant(&init_tuple_elem))) {
783       VLOG(3) << "Found loop invariant tuple element " << i << " "
784               << init_tuple_elem->ToString();
785       index_to_constant[i] = init_tuple_elem;
786     }
787   }
788 
789   if (index_to_constant.empty()) {
790     return false;
791   }
792 
793   // Replace the use of each constant tuple element in the loop_condition and
794   // loop_body with the corresponding constant value.
795   auto propagate_constant = [&](HloComputation* computation) -> StatusOr<bool> {
796     HloInstruction* param = computation->parameter_instruction(0);
797     bool changed = false;
798     for (auto instr : param->users()) {
799       // Since only a while-loop with a tuple result reaches here, we can safely
800       // assume that `param` is a tuple and the first operand of the
801       // GetTupleElement instruction is a use of `param`.
802       if (instr->opcode() == HloOpcode::kGetTupleElement) {
803         VLOG(3) << "tuple index " << instr->tuple_index() << " "
804                 << instr->ToString();
805         auto iter = index_to_constant.find(instr->tuple_index());
806         if (iter != index_to_constant.end()) {
807           const HloInstruction* hlo_constant = (*iter).second;
808           VLOG(3) << "Replace use of " << instr->ToString() << " with "
809                   << hlo_constant->ToString();
810           TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(
811               computation->AddInstruction(hlo_constant->Clone())));
812           changed = true;
813         }
814       }
815     }
816     return changed;
817   };
818 
819   TF_ASSIGN_OR_RETURN(bool changed_cond,
820                       propagate_constant(while_op->while_condition()));
821   TF_ASSIGN_OR_RETURN(bool changed_body, propagate_constant(while_body));
822 
823   return changed_cond || changed_body;
824 }
825 
826 // Converts a flat list of instructions into a tuple of the desired shape.  For
827 // example, given a tuple shape ((x, x), x) and instructions {A, B, C}, returns
828 // a tuple of value ((A, B), C).
829 //
830 // desired_shape must be a tuple.  (This precondition allows us to return a
831 // unique_ptr rather than a raw ptr.)
UnflattenTupleInstr(absl::Span<HloInstruction * > instrs,const Shape & desired_shape,std::vector<std::unique_ptr<HloInstruction>> * new_instrs)832 static std::unique_ptr<HloInstruction> UnflattenTupleInstr(
833     absl::Span<HloInstruction*> instrs, const Shape& desired_shape,
834     std::vector<std::unique_ptr<HloInstruction>>* new_instrs) {
835   CHECK(desired_shape.IsTuple()) << ShapeUtil::HumanString(desired_shape);
836 
837   // For each child shape in `desired_shape`, slice out the correct number of
838   // `instrs` and call UnflattenTupleInstr recursively.  At each step we remove
839   // elements from `instrs` so that it only contains instructions we have not
840   // yet processed.
841   std::vector<HloInstruction*> elems;
842   for (int64_t i = 0; i < desired_shape.tuple_shapes_size(); ++i) {
843     const Shape& subshape = desired_shape.tuple_shapes(i);
844     if (!subshape.IsTuple()) {
845       elems.push_back(instrs[0]);
846       instrs.remove_prefix(1);
847       continue;
848     }
849 
850     // Count the number of leaf nodes underneath desired_shape[i].
851     int64_t num_leaves = 0;
852     ShapeUtil::ForEachSubshape(
853         subshape, [&](const Shape& s, const ShapeIndex& /*index*/) {
854           if (!s.IsTuple()) {
855             ++num_leaves;
856           }
857         });
858 
859     std::unique_ptr<HloInstruction> subinstr =
860         UnflattenTupleInstr(instrs.subspan(0, num_leaves),
861                             desired_shape.tuple_shapes(i), new_instrs);
862     elems.push_back(subinstr.get());
863     new_instrs->push_back(std::move(subinstr));
864     instrs.remove_prefix(num_leaves);
865   }
866   return HloInstruction::CreateTuple(elems);
867 }
868 
869 // Builds a vector whose elements are the values in the flattened tuple for
870 // `instr`.  For example, if `instr` is a tuple of form ((A, B), C), returns the
871 // vector {A, B, C} (or kGetTupleElement ops which point to A, B, and C).
GetFlatTupleElems(HloInstruction * instr,std::vector<std::unique_ptr<HloInstruction>> * new_instrs)872 static std::vector<HloInstruction*> GetFlatTupleElems(
873     HloInstruction* instr,
874     std::vector<std::unique_ptr<HloInstruction>>* new_instrs) {
875   const auto& shape = instr->shape();
876   if (!shape.IsTuple()) {
877     return {instr};
878   }
879   std::vector<HloInstruction*> elems;
880   for (int64_t i = 0; i < shape.tuple_shapes_size(); ++i) {
881     const Shape& subshape = shape.tuple_shapes(i);
882     new_instrs->push_back(
883         HloInstruction::CreateGetTupleElement(subshape, instr, i));
884     auto* gte = new_instrs->back().get();
885     auto flattened_subshape = GetFlatTupleElems(gte, new_instrs);
886     elems.insert(elems.end(), flattened_subshape.begin(),
887                  flattened_subshape.end());
888   }
889   return elems;
890 }
891 
TryFlattenNestedTuples(HloInstruction * while_op)892 static StatusOr<bool> TryFlattenNestedTuples(HloInstruction* while_op) {
893   HloModule* module = while_op->GetModule();
894   HloComputation* computation = while_op->parent();
895   auto* while_init = while_op->mutable_operand(0);
896   auto* while_body = while_op->while_body();
897   auto* while_cond = while_op->while_condition();
898   auto* while_body_root = while_body->root_instruction();
899   if (while_init->opcode() != HloOpcode::kTuple ||
900       while_body_root->opcode() != HloOpcode::kTuple) {
901     return false;
902   }
903 
904   TF_RET_CHECK(while_cond->num_parameters() == 1);
905   TF_RET_CHECK(while_body->num_parameters() == 1);
906   TF_RET_CHECK(
907       ShapeUtil::Compatible(while_init->shape(), while_body_root->shape()));
908   Shape while_shape = while_init->shape();
909   if (!ShapeUtil::IsNestedTuple(while_shape)) {
910     return false;
911   }
912 
913   std::vector<Shape> flattened_shape_elems;
914   ShapeUtil::ForEachSubshape(while_shape,
915                              [&](const Shape& s, const ShapeIndex& /*index*/) {
916                                if (!s.IsTuple()) {
917                                  flattened_shape_elems.push_back(s);
918                                }
919                              });
920   Shape flattened_shape = ShapeUtil::MakeTupleShape(flattened_shape_elems);
921 
922   // `new_instrs` holds instructions created outside of a computation for
923   // cloning.  Elements added here just need to live until the end of the
924   // relevant CloneWithReplacement call.
925   std::vector<std::unique_ptr<HloInstruction>> new_instrs;
926   auto add_new_instr = [&](std::unique_ptr<HloInstruction> instr) {
927     new_instrs.push_back(std::move(instr));
928     return new_instrs.back().get();
929   };
930 
931   auto nested = [&](HloInstruction* instr) {
932     std::vector<HloInstruction*> gtes;
933     const Shape& flat_shape = instr->shape();
934     for (int64_t i = 0; i < flat_shape.tuple_shapes_size(); ++i) {
935       gtes.push_back(add_new_instr(HloInstruction::CreateGetTupleElement(
936           flat_shape.tuple_shapes(i), instr, i)));
937     }
938     auto nested_instr =
939         UnflattenTupleInstr(absl::MakeSpan(gtes), while_shape, &new_instrs);
940     CHECK(ShapeUtil::Compatible(nested_instr->shape(), while_shape))
941         << ShapeUtil::HumanString(nested_instr->shape()) << " vs "
942         << ShapeUtil::HumanString(while_shape);
943     return nested_instr;
944   };
945 
946   auto flattened = [&](HloInstruction* instr) {
947     return HloInstruction::CreateTuple(GetFlatTupleElems(instr, &new_instrs));
948   };
949 
950   // Create a new while-condition computation, where parameter 0 has flat shape
951   // but all uses of it go through the nested shape.
952   std::unique_ptr<HloComputation> new_while_cond =
953       while_cond->CloneWithReplacementPairs({
954           while_cond->parameter_instruction(0),
955           nested(add_new_instr(HloInstruction::CreateParameter(
956               0, flattened_shape,
957               while_cond->parameter_instruction(0)->name()))),
958       });
959 
960   // Create a new while-body computation, where parameter 0 has a flat shape and
961   // all uses of it go through the nested shape, and where the root has a flat
962   // shape constructed from the old nested root.
963   std::unique_ptr<HloComputation> new_while_body =
964       while_body->CloneWithReplacementPairs(
965           {
966               while_body->parameter_instruction(0),
967               nested(add_new_instr(HloInstruction::CreateParameter(
968                   0, flattened_shape,
969                   while_body->parameter_instruction(0)->name()))),
970           },
971           {
972               while_body->root_instruction(),
973               flattened(add_new_instr(while_body->root_instruction()->Clone())),
974           });
975 
976   // Create the final while loop, and add any new instructions created to
977   // `computation`.
978   new_instrs.clear();
979   auto* new_while_op = computation->AddInstruction(HloInstruction::CreateWhile(
980       flattened_shape,
981       module->AddEmbeddedComputation(std::move(new_while_cond)),
982       module->AddEmbeddedComputation(std::move(new_while_body)),
983       computation->AddInstruction(flattened(while_init))));
984   CopyBackendConfig(while_op, new_while_op);
985   CopyFrontendAttributes(while_op, new_while_op);
986   TF_RETURN_IF_ERROR(
987       computation->ReplaceWithNewInstruction(while_op, nested(new_while_op)));
988   for (auto& instr : new_instrs) {
989     computation->AddInstruction(std::move(instr));
990   }
991   return true;
992 }
993 
994 // Tries to merge loop induction variables of a given type.
995 //
996 // In this pass we're only concerned with elements of the loop's tuple that
997 // are effective-scalars of type `elem_ty`.  Some terminology:
998 //
999 //  - The trip counter is the first element of the loop's tuple that starts at
1000 //    0 and does x++ on each iteration.
1001 //
1002 //  - An induction variable is an element of the loop's tuple that is not the
1003 //    trip counter and does `x += <constant>` on each iteration of the loop.
1004 //    Negative constants are OK.
1005 //
1006 // This pass adds a trip counter if one isn't already present, then replaces
1007 // each induction variable with
1008 //
1009 //   <initial_value> + <trip_count> * <constant>.
1010 //
1011 // This reduces the number of scalar operations in the loop, which is important
1012 // e.g. on GPUs, where each scalar operation is nontrivially expensive because
1013 // it's a separate kernel launch.
1014 //
1015 // Returns the new loop if a change was made, or null if no change was made.
1016 // Note that the new loop is not a valid replacement for the old loop; it may
1017 // need to be wrapped in a tuple that changes its shape.  We return the loop
1018 // itself so that you can call TryMergeInductionVariables in a loop, once for
1019 // each integral type elem_ty.
TryMergeInductionVariables(HloInstruction * while_op,PrimitiveType elem_ty)1020 static StatusOr<HloInstruction*> TryMergeInductionVariables(
1021     HloInstruction* while_op, PrimitiveType elem_ty) {
1022   CHECK(primitive_util::IsIntegralType(elem_ty)) << PrimitiveType_Name(elem_ty);
1023   HloModule* module = while_op->GetModule();
1024   HloComputation* computation = while_op->parent();
1025   auto* while_init = while_op->mutable_operand(0);
1026   auto* while_body = while_op->while_body();
1027   auto* while_cond = while_op->while_condition();
1028   auto* while_body_root = while_body->root_instruction();
1029   if (while_init->opcode() != HloOpcode::kTuple ||
1030       while_body_root->opcode() != HloOpcode::kTuple) {
1031     return nullptr;
1032   }
1033 
1034   TF_RET_CHECK(while_cond->num_parameters() == 1);
1035   TF_RET_CHECK(while_body->num_parameters() == 1);
1036   TF_RET_CHECK(
1037       ShapeUtil::Compatible(while_init->shape(), while_body_root->shape()));
1038   Shape while_shape = while_init->shape();
1039 
1040   // The tuple index of the trip counter, if one is present.
1041   absl::optional<int64> trip_counter;
1042   // Maps the tuple index of each induction variable to its constant increment.
1043   absl::flat_hash_map<int64, const HloConstantInstruction*> induction_vars;
1044   for (int64_t i = 0; i < while_body_root->operand_count(); ++i) {
1045     HloInstruction* constant;
1046     if (!Match(while_body_root->mutable_operand(i),
1047                m::AddAnyOrder(m::GetTupleElement(m::Parameter(), i),
1048                               m::ConstantScalar(&constant))
1049                    .WithShape(m::Shape().WithElementType(elem_ty)))) {
1050       continue;
1051     }
1052     if (!trip_counter && constant->literal().IsAll(1) &&
1053         while_init->operand(i)->IsConstant() &&
1054         while_init->operand(i)->literal().IsAll(0)) {
1055       VLOG(10) << "Found existing trip counter at index " << i;
1056       trip_counter = i;
1057     } else {
1058       VLOG(10) << "Found induction variable at index " << i;
1059       induction_vars.emplace(i, Cast<HloConstantInstruction>(constant));
1060     }
1061   }
1062 
1063   // There's only something to simplify if we can either:
1064   //
1065   //  - combine one or more induction vars with an existing trip counter, or
1066   //  - replace two or more induction variables with a new trip counter.
1067   //
1068   // Put another way, there's only something to simplify if the number of
1069   // induction vars plus the number of existing trip counters (0 or 1) is >= 2.
1070   if (induction_vars.size() + (trip_counter.has_value() ? 1 : 0) < 2) {
1071     return nullptr;
1072   }
1073 
1074   // OK, we're going to do the transformation!  Set up some helpers.
1075 
1076   // `new_instrs` holds instructions created outside of a computation for
1077   // cloning.  Elements added here just need to live until the end of the
1078   // relevant CloneWithReplacement call.
1079   std::vector<std::unique_ptr<HloInstruction>> new_instrs;
1080   auto add_new_instr = [&](std::unique_ptr<HloInstruction> instr) {
1081     new_instrs.push_back(std::move(instr));
1082     return new_instrs.back().get();
1083   };
1084 
1085   auto add_binary_op = [&](const Shape& shape, HloOpcode opcode,
1086                            HloInstruction* lhs, HloInstruction* rhs) {
1087     // Reshape lhs/rhs to the output shape if necessary.  This deals with the
1088     // fact that induction variables need only be effective scalars, not true
1089     // scalars.
1090     if (!ShapeUtil::Compatible(shape, lhs->shape())) {
1091       lhs = add_new_instr(HloInstruction::CreateReshape(shape, lhs));
1092     }
1093     if (!ShapeUtil::Compatible(shape, rhs->shape())) {
1094       rhs = add_new_instr(HloInstruction::CreateReshape(shape, rhs));
1095     }
1096     return add_new_instr(HloInstruction::CreateBinary(shape, opcode, lhs, rhs));
1097   };
1098 
1099   auto add_gte = [&](HloInstruction* src, int64_t idx) {
1100     return add_new_instr(HloInstruction::CreateGetTupleElement(
1101         src->shape().tuple_shapes(idx), src, idx));
1102   };
1103 
1104   // Our new while loop will have the same shape as the old while loop, except
1105   // we'll add a trip counter to the end if it wasn't originally present.
1106   Shape new_while_shape = while_shape;
1107   bool added_trip_counter = false;
1108   if (!trip_counter) {
1109     VLOG(10) << "Adding new trip counter to end of loop's tuple.";
1110     trip_counter = new_while_shape.tuple_shapes_size();
1111     *new_while_shape.add_tuple_shapes() =
1112         ShapeUtil::MakeShape(elem_ty, /*dimensions=*/{});
1113     added_trip_counter = true;
1114   }
1115 
1116   // Converts `instr` into a tuple of the "old" form -- that is, to a tuple with
1117   // shape `while_body->shape()` and where the induction variables are "reified"
1118   // (i.e. they have value <init> + <counter> * <constant>).
1119   auto convert_to_old_form = [&](HloInstruction* instr) {
1120     CHECK(ShapeUtil::Compatible(instr->shape(), new_while_shape));
1121     std::vector<HloInstruction*> tuple_elems;
1122     for (int64_t i = 0; i < while_shape.tuple_shapes_size(); ++i) {
1123       const auto& elem_shape = while_shape.tuple_shapes(i);
1124       if (!induction_vars.count(i)) {
1125         tuple_elems.push_back(add_gte(instr, i));
1126         continue;
1127       }
1128       tuple_elems.push_back(add_binary_op(
1129           elem_shape, HloOpcode::kAdd, add_gte(instr, i),
1130           add_binary_op(elem_shape, HloOpcode::kMultiply,
1131                         add_gte(instr, *trip_counter),
1132                         add_new_instr(induction_vars.at(i)->Clone()))));
1133     }
1134     return HloInstruction::CreateTuple(tuple_elems);
1135   };
1136 
1137   // Converts `root` into a tuple of the "new" form -- that is, to a tuple with
1138   // shape `new_while_shape` and where the induction variables (but not trip
1139   // counters) are replaced with their unchanging <loop_body_param> values.
1140   auto convert_to_new_form = [&](HloInstruction* old_root,
1141                                  HloParameterInstruction* loop_body_param) {
1142     CHECK(ShapeUtil::Compatible(old_root->shape(), while_shape));
1143     std::vector<HloInstruction*> tuple_elems;
1144 
1145     // In the new form, induction variables come from `init`, everything else
1146     // (including the trip counter if it's not one we created ourselves) comes
1147     // from the `root` tuple unmodified.
1148     for (int64_t i = 0; i < while_shape.tuple_shapes_size(); ++i) {
1149       tuple_elems.push_back(
1150           add_gte((induction_vars.count(i) ? loop_body_param : old_root), i));
1151     }
1152     // If we created a trip counter ourselves, add 1 to it in the next
1153     // iteration.
1154     if (added_trip_counter) {
1155       tuple_elems.push_back(add_binary_op(
1156           new_while_shape.tuple_shapes(*trip_counter), HloOpcode::kAdd,
1157           add_gte(loop_body_param, *trip_counter),
1158           add_new_instr(
1159               HloInstruction::CreateConstant(LiteralUtil::One(elem_ty)))));
1160     }
1161 
1162     return HloInstruction::CreateTuple(tuple_elems);
1163   };
1164 
1165   // Creates a new init tuple, which is the same as the old init tuple except if
1166   // we added a trip counter, it's set to 0.
1167   auto get_new_while_init = [&](HloInstruction* init) {
1168     CHECK(ShapeUtil::Compatible(init->shape(), while_shape));
1169     if (!added_trip_counter) {
1170       return init;
1171     }
1172     std::vector<HloInstruction*> tuple_elems;
1173     for (int64_t i = 0; i < while_shape.tuple_shapes_size(); ++i) {
1174       tuple_elems.push_back(add_gte(init, i));
1175     }
1176     tuple_elems.push_back(add_new_instr(
1177         HloInstruction::CreateConstant(LiteralUtil::Zero(elem_ty))));
1178     return add_new_instr(HloInstruction::CreateTuple(tuple_elems));
1179   };
1180 
1181   std::unique_ptr<HloComputation> new_while_cond =
1182       while_cond->CloneWithReplacementPairs({
1183           while_cond->parameter_instruction(0),
1184           convert_to_old_form(add_new_instr(HloInstruction::CreateParameter(
1185               0, new_while_shape,
1186               while_cond->parameter_instruction(0)->name()))),
1187       });
1188 
1189   // Creating the new while body proceeds in two steps.  First we convert the
1190   // users of the parameter to the old form.  Then as a second
1191   // CloneWithReplacement operation we convert the root to the new form.  We
1192   // have to do this in two steps because the new root needs to use the new
1193   // param0, and during the first clone operation, only the *old-form* param0 is
1194   // accessible.
1195   //
1196   // We have to add temp_new_while_body to the module because cloning a
1197   // computation touches the module (to get its NameUniquer).
1198   HloComputation* temp_new_while_body =
1199       module->AddEmbeddedComputation(while_body->CloneWithReplacementPairs({
1200           while_body->parameter_instruction(0),
1201           convert_to_old_form(add_new_instr(HloInstruction::CreateParameter(
1202               0, new_while_shape,
1203               while_body->parameter_instruction(0)->name()))),
1204       }));
1205   std::unique_ptr<HloComputation> new_while_body =
1206       temp_new_while_body->CloneWithReplacementPairs({
1207           temp_new_while_body->root_instruction(),
1208           convert_to_new_form(
1209               add_new_instr(temp_new_while_body->root_instruction()->Clone()),
1210               Cast<HloParameterInstruction>(
1211                   temp_new_while_body->parameter_instruction(0))),
1212       });
1213   TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(temp_new_while_body));
1214 
1215   // Create the final while loop, and add any new instructions created to
1216   // `computation`.
1217   new_instrs.clear();
1218   auto* new_while = computation->AddInstruction(HloInstruction::CreateWhile(
1219       new_while_shape,
1220       module->AddEmbeddedComputation(std::move(new_while_cond)),
1221       module->AddEmbeddedComputation(std::move(new_while_body)),
1222       get_new_while_init(while_init)));
1223   CopyBackendConfig(while_op, new_while);
1224   CopyFrontendAttributes(while_op, new_while);
1225   TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
1226       while_op, convert_to_old_form(new_while)));
1227   for (auto& instr : new_instrs) {
1228     computation->AddInstruction(std::move(instr));
1229   }
1230   return new_while;
1231 }
1232 
Run(HloModule * module)1233 StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
1234   XLA_VLOG_LINES(3,
1235                  "WhileLoopSimplifier::Run(), before:\n" + module->ToString());
1236   bool changed = false;
1237 
1238   // Gather all the while ops in our module.  We do this ahead of time so we
1239   // don't have to worry about mutating the lists of computations or
1240   // instructions while we iterate.
1241   std::vector<HloInstruction*> while_ops;
1242   for (auto* comp : module->computations()) {
1243     for (auto* instr : comp->instructions()) {
1244       if (instr->opcode() == HloOpcode::kWhile) {
1245         while_ops.push_back(instr);
1246       }
1247     }
1248   }
1249 
1250   for (HloInstruction* while_op : while_ops) {
1251     // We can't remove while loops that contain send/recv nodes, because we rely
1252     // on the particular loop structure around the node matching on the send and
1253     // recv sides.  Other while simplifications require us to remove the loop
1254     // and replace it with a new one, so we can't do that either.
1255     if (ContainsInstrWithOpcode(while_op->while_body(),
1256                                 {HloOpcode::kSend, HloOpcode::kSendDone,
1257                                  HloOpcode::kRecv, HloOpcode::kRecvDone}) ||
1258         ContainsInstrWithOpcode(while_op->while_condition(),
1259                                 {HloOpcode::kSend, HloOpcode::kSendDone,
1260                                  HloOpcode::kRecv, HloOpcode::kRecvDone})) {
1261       VLOG(2) << "Not attempting to simplify while loop because it contains a "
1262                  "send/recv node: "
1263               << while_op->ToShortString();
1264       continue;
1265     }
1266 
1267     TF_ASSIGN_OR_RETURN(bool result, TryPropagateConstant(while_op));
1268     changed |= result;
1269 
1270     TF_ASSIGN_OR_RETURN(result, TryRemoveWhileLoop(while_op));
1271     changed |= result;
1272 
1273     if (result) {
1274       // Don't continue simplifying after successfully removing the while loop
1275       // -- that would result in use-after-free nastiness.
1276       continue;
1277     }
1278 
1279     // TODO(b/119281462): Cowardly refuse to perform any of the following
1280     // optimizations in the presence of kDomain instructions.  It seems that
1281     // modifying a while loop's tuple doesn't work when kDomain is present.
1282     if (ContainsInstrWithOpcode(while_op->while_body(), {HloOpcode::kDomain}) ||
1283         ContainsInstrWithOpcode(while_op->while_condition(),
1284                                 {HloOpcode::kDomain})) {
1285       continue;
1286     }
1287 
1288     // Each of the optimizations below modifies the while loop itself if it's
1289     // successful, meaning that `while_op` is no longer valid after one of these
1290     // transformations returns true.
1291 
1292     TF_ASSIGN_OR_RETURN(result, TryRemoveRepeatedWhileTupleIndices(while_op));
1293     changed |= result;
1294     if (result) {
1295       continue;
1296     }
1297 
1298     TF_ASSIGN_OR_RETURN(result, TryFlattenNestedTuples(while_op));
1299     changed |= result;
1300     if (result) {
1301       continue;
1302     }
1303 
1304     TF_ASSIGN_OR_RETURN(result, TryRemoveDeadWhileParams(while_op));
1305 
1306     changed |= result;
1307     if (result) {
1308       continue;
1309     }
1310 
1311     TF_ASSIGN_OR_RETURN(result, TryRemoveConstantParams(while_op));
1312     changed |= result;
1313     if (result) {
1314       continue;
1315     }
1316 
1317     bool merged_induction_vars = false;
1318     // Notably missing from this list are S16 and U16.  These don't currently
1319     // work because S/U16 literals are not implemented.
1320     for (auto elem_ty : {S8, U8, S32, U32, S64, U64}) {
1321       TF_ASSIGN_OR_RETURN(auto* new_while_op,
1322                           TryMergeInductionVariables(while_op, elem_ty));
1323       if (new_while_op) {
1324         while_op = new_while_op;
1325         changed = true;
1326         merged_induction_vars = true;
1327       }
1328     }
1329     if (merged_induction_vars) {
1330       continue;
1331     }
1332   }
1333 
1334   XLA_VLOG_LINES(3,
1335                  "WhileLoopSimplifier::Run(), after:\n" + module->ToString());
1336   return changed;
1337 }
1338 
1339 }  // namespace xla
1340