• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/while_loop_analysis.h"
17 
18 #include "absl/base/casts.h"
19 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
20 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
21 #include "tensorflow/compiler/xla/service/hlo_module.h"
22 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
23 #include "tensorflow/compiler/xla/service/hlo_reachability.h"
24 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
25 
26 namespace xla {
27 
28 using absl::nullopt;
29 using absl::optional;
30 namespace m = match;
31 
32 // Finds and returns the non-constant operand in instr.
33 //
34 // CHECK-fails if instr doesn't have exactly one unique non-constant operand.
NonConstantOperand(const HloInstruction * instr)35 static const HloInstruction* NonConstantOperand(const HloInstruction* instr) {
36   const HloInstruction* result = nullptr;
37   for (const HloInstruction* operand : instr->operands()) {
38     if (!operand->IsConstant()) {
39       if (result != nullptr) {
40         CHECK_EQ(result, operand);
41       }
42       result = operand;
43     }
44   }
45   CHECK_NE(result, nullptr);
46   return result;
47 }
48 
49 // If all of instr's operands are either constants or have the form
50 //   get-tuple-element(gte_operand, N)
51 // for the same value N, returns N.  Otherwise, returns nullopt.
GetGTEOperandIndex(const HloInstruction * instr,const HloInstruction * gte_operand)52 static optional<int64> GetGTEOperandIndex(const HloInstruction* instr,
53                                           const HloInstruction* gte_operand) {
54   VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", "
55           << gte_operand->ToString() << ")";
56 
57   // Among the operands of `instr`, find one that is a get-tuple-element op or
58   // the one that is copy fed by a get-tuple-element.
59   auto gte_it =
60       absl::c_find_if(instr->operands(), [](const HloInstruction* instr) {
61         return (instr->opcode() == HloOpcode::kGetTupleElement) ||
62                (instr->opcode() == HloOpcode::kCopy &&
63                 instr->operand(0)->opcode() == HloOpcode::kGetTupleElement);
64       });
65   if (gte_it == instr->operands().end()) {
66     VLOG(2) << "instr does not have a gte operand.";
67     return nullopt;
68   }
69 
70   // All operands of `instr` must be either constants or of the form
71   //   get-tuple-element(gte_operand, tuple_idx)
72   // for the same value tuple_idx.
73   int64_t tuple_idx = (*gte_it)->tuple_index();
74   for (const HloInstruction* operand : instr->operands()) {
75     if (!Match(operand, m::Constant()) &&
76         !Match(operand,
77                m::GetTupleElement(m::Op().Is(gte_operand), tuple_idx))) {
78       VLOG(2)
79           << "instr uses something other than a constant or gte(gte_operand, "
80           << tuple_idx << "): " << operand->ToString();
81       return nullopt;
82     }
83   }
84   return tuple_idx;
85 }
86 
87 // The below function identifies a subset of all possible auxiliary
88 // induction variables (AIV). Specifically, candidates are gtes, e.g.,
89 // gte(param0, N)
90 // The function checks if the loop body plumbs the AIV
91 // through the same tuple index at root, and that ops involving AIV
92 // involve constants.
93 //   op2 = op(constants, gte(param0, N), constants)
94 //   op3 = op(constants, f(op2, gte(param0, N), constants)
95 //   op4 = op(constants, f(op3, constants)
96 //   root = tuple(..., op4, ...)
97 // Further, the ops are restricted to basic math ops (+,-,*,/).
98 // Finally, loop invariant GTEs are excluded from AIVs.
99 // We can expand the ops category/nature of AIVs as needed.
GetAuxiliaryLoopInductionVars(const HloInstruction * while_op)100 std::vector<const HloInstruction*> GetAuxiliaryLoopInductionVars(
101     const HloInstruction* while_op) {
102   std::vector<const HloInstruction*> aux_ind_gte;
103   CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
104   auto* while_body = while_op->while_body();
105   auto* while_body_param = while_body->parameter_instruction(0);
106   VLOG(2) << "Aux Induction Variables for loop:" << while_op->ToShortString();
107   VLOG(2) << "the parameter instr:" << while_body_param->ToShortString();
108   VLOG(2) << "the parameter user count:" << while_body_param->users().size();
109   if (while_body_param == nullptr) return aux_ind_gte;
110 
111   // candidates_pairs = pair<inst, inst>(
112   //   operands of the root while body,
113   //   GTE only operands that index into the same position in the parameter)
114   // for each candidate_pair (x, y)
115   //  find all paths between x and y,
116   //  each paths should satisfy the above listed criterion
117   //  index that x and y used is added as a aux variable index
118   std::map<int64, const HloInstruction*> extractions;
119   for (const HloInstruction* indx_instr : while_body_param->users()) {
120     if (indx_instr->opcode() != HloOpcode::kGetTupleElement) {
121       continue;
122     }
123     auto it = extractions.find(indx_instr->tuple_index());
124     // if we find two extractions at the same index, we ignore such
125     // a candidate
126     if (it != extractions.end()) {
127       it->second = nullptr;
128       VLOG(2) << "two extractions at same index:" << indx_instr->ToString();
129     } else {
130       extractions.insert(std::make_pair(indx_instr->tuple_index(), indx_instr));
131       VLOG(2) << "inserting extraction :" << indx_instr->ToString();
132     }
133   }
134   VLOG(2) << "total extractions size:" << extractions.size() << std::endl;
135   if (extractions.empty()) {
136     return aux_ind_gte;
137   }
138 
139   auto* while_body_root = while_body->root_instruction();
140   if (while_body_root->opcode() != HloOpcode::kTuple) {
141     VLOG(2) << "While body root is not a tuple:" << while_body_root->ToString();
142     return aux_ind_gte;
143   }
144   int64_t index = -1;
145   std::map<int64, const HloInstruction*> insertions;
146   for (const HloInstruction* operand : while_body_root->operands()) {
147     index++;
148     if (!operand->IsConstant()) {
149       auto it = insertions.find(index);
150       if (it != insertions.end()) {
151         it->second = nullptr;
152         VLOG(2) << "two insertions at same index:" << operand->ToString();
153       } else {
154         insertions.insert(std::make_pair(index, operand));
155         VLOG(2) << "inserting insertions:" << operand->ToString();
156       }
157     }
158   }
159   if (insertions.empty()) {
160     return aux_ind_gte;
161   }
162 
163   std::map<int64, std::pair<const HloInstruction*, const HloInstruction*>>
164       candidate_pairs;
165   for (; index >= 0; --index) {
166     const HloInstruction *ext, *inst;
167     ext = (extractions.find(index) != extractions.end())
168               ? extractions.find(index)->second
169               : nullptr;
170     inst = (insertions.find(index) != insertions.end())
171                ? insertions.find(index)->second
172                : nullptr;
173     if (ext != nullptr && inst != nullptr) {
174       // Filter out trivial aux, i.e., extract directly to an insert.
175       if (ext != inst) {
176         candidate_pairs.insert(
177             std::make_pair(index, std::make_pair(ext, inst)));
178       }
179     }
180   }
181   VLOG(2) << "total candidate pairs:" << candidate_pairs.size() << std::endl;
182 
183   // Passed to ReachabilityMap to decide the type of produce-consumer edges
184   // along the reachability path.
185   const auto add_dependencies = [](const HloInstruction* hlo,
186                                    std::vector<HloInstruction*>* inputs) {
187     HloInstruction* non_const_operand = nullptr;
188     int num_non_constants = 0;
189     for (HloInstruction* operand : hlo->operands()) {
190       if (!operand->IsConstant()) {
191         num_non_constants++;
192         non_const_operand = operand;
193       }
194     }
195     if (num_non_constants == 1 &&
196         (hlo->opcode() == HloOpcode::kGetTupleElement ||
197          hlo->opcode() == HloOpcode::kAdd ||
198          hlo->opcode() == HloOpcode::kMultiply ||
199          hlo->opcode() == HloOpcode::kDivide ||
200          hlo->opcode() == HloOpcode::kSubtract)) {
201       inputs->push_back(non_const_operand);
202     }
203   };
204 
205   std::unique_ptr<HloReachabilityMap> hrm =
206       HloReachabilityMap::BuildWithRestrictions(
207           while_body,
208           absl::FunctionRef<void(const HloInstruction* hlo,
209                                  std::vector<HloInstruction*>* inputs)>(
210               add_dependencies));
211 
212   for (auto candidates : candidate_pairs) {
213     VLOG(2) << "are reachable?:" << (candidates.second.first)->ToString()
214             << "*************" << (candidates.second.second)->ToString()
215             << std::endl;
216     if (hrm->IsReachable(candidates.second.first, candidates.second.second)) {
217       aux_ind_gte.push_back(candidates.second.first);
218       VLOG(2) << "YES";
219     } else {
220       VLOG(2) << "NO";
221     }
222   }
223   VLOG(2) << "num auxiliary candidates :" << aux_ind_gte.size();
224   return aux_ind_gte;
225 }
226 
227 // Tries to get the tuple index of the induction variable of a while loop.
228 //
229 // Checks that the loop condition and body both plumb the induction variable
230 // through the same tuple index, and that they both apply exactly one op to the
231 // induction variable before  deciding whether to do another loop iteration (in
232 // the loop condition's case) or packing the induction variable into the result
233 // tuple (in the loop body's case).
234 //
235 // Specifically, checks that the loop condition has structure
236 //
237 //   root = op(constants, get-tuple-elem(param0, N), constants)
238 //
239 // and the loop body has the structure
240 //
241 //   inc = op(constants, get-tuple-elem(param0, N), constants)
242 //   root = tuple(..., inc, ...)  // inc is N'th operand of tuple().
243 //
244 // If so, returns N.  Otherwise, returns nullopt.
GetLoopInductionVarTupleIdx(const HloInstruction * while_op)245 optional<int64> GetLoopInductionVarTupleIdx(const HloInstruction* while_op) {
246   CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
247   VLOG(2) << "Finding induction variable for loop "
248           << while_op->ToShortString();
249 
250   // The while_cond computation should have the form
251   //
252   //   while_cond_root =
253   //       op(constants, get-tuple-elem(while_cond_param, N), constants).
254   //
255   // If it does, set indvar_tuple_idx to N.
256   auto* while_cond = while_op->while_condition();
257   auto* while_cond_root = while_cond->root_instruction();
258   auto* while_cond_param = while_cond->parameter_instruction(0);
259   optional<int64> indvar_tuple_idx =
260       GetGTEOperandIndex(while_cond_root, while_cond_param);
261   if (!indvar_tuple_idx) {
262     VLOG(2) << "Induction variable not found in loop condition: "
263             << while_cond->root_instruction()->ToString();
264     return nullopt;
265   }
266 
267   // The while_body computation should have the form
268   //
269   //   while_body_inc =
270   //       op(constants, get-tuple-elem(while_body_param, N), constants)
271   //   while_body_root = tuple(..., while_body_inc, ...)
272   //
273   // where while_body_inc is operand N of while_body_root.
274   auto* while_body = while_op->while_body();
275   auto* while_body_root = while_body->root_instruction();
276   if (while_body_root->opcode() != HloOpcode::kTuple) {
277     VLOG(2) << "While body's root is not a tuple instruction: "
278             << while_body_root->ToString();
279     return nullopt;
280   }
281 
282   auto* while_body_inc = while_body_root->operand(*indvar_tuple_idx);
283   auto* while_body_param = while_body->parameter_instruction(0);
284   optional<int64> while_body_indvar_tuple_idx =
285       GetGTEOperandIndex(while_body_inc, while_body_param);
286   if (!while_body_indvar_tuple_idx) {
287     VLOG(2)
288         << "Induction variable not found in while body increment instruction: "
289         << while_body_inc->ToString();
290     return nullopt;
291   }
292   if (while_body_indvar_tuple_idx != indvar_tuple_idx) {
293     VLOG(2) << "Tuple index of induction variable does not match between loop "
294                "condition ("
295             << *indvar_tuple_idx << ") and while body ("
296             << *while_body_indvar_tuple_idx << ")";
297     return nullopt;
298   }
299 
300   // Finally, check that the while loop's initial value is a tuple with enough
301   // elements.
302   auto* while_init = while_op->operand(0);
303   if (while_init->opcode() != HloOpcode::kTuple) {
304     VLOG(2) << "While init expected to be a tuple: " << while_init->ToString();
305     return nullopt;
306   }
307 
308   VLOG(2) << "Induction variable's tuple index: " << *indvar_tuple_idx;
309   return indvar_tuple_idx;
310 }
311 
312 // Converts the given literal to a scalar int64, if possible.
313 //
314 // Fails if the literal is not an integral type or if the value it contains
315 // cannot be represented in an int64.
LiteralAsScalarInt64(const Literal & l)316 static optional<int64> LiteralAsScalarInt64(const Literal& l) {
317   if (!ShapeUtil::IsEffectiveScalar(l.shape())) {
318     VLOG(2) << "literal is not an effective scalar: " << l.ToString();
319     return nullopt;
320   }
321   switch (l.shape().element_type()) {
322     case S8:
323       return l.GetFirstElement<int8>();
324     case S16:
325       return l.GetFirstElement<int16>();
326     case S32:
327       return l.GetFirstElement<int32>();
328     case S64:
329       return l.GetFirstElement<int64>();
330     case U8:
331       return l.GetFirstElement<uint8>();
332     case U16:
333       return l.GetFirstElement<uint16>();
334     case U32:
335       return l.GetFirstElement<uint32>();
336     case U64: {
337       uint64 v = l.GetFirstElement<uint64>();
338       if (v > static_cast<uint64>(std::numeric_limits<int64>::max())) {
339         VLOG(2) << "uint64 literal is out of range for int64: " << v;
340         return nullopt;
341       }
342       return v;
343     }
344     default:
345       VLOG(2) << "literal is of non-integral type " << l.shape().ToString();
346       return nullopt;
347   }
348 }
349 
350 // Computes a + b, returning nullopt if it overflows.
CheckedAdd(int64_t a,int64_t b)351 optional<int64> CheckedAdd(int64_t a, int64_t b) {
352   // Overflow occurred iff `a` and `b` have the same sign and `a + b` has a
353   // different sign, see Hacker's Delignt 2nd Ed. pp 28.
354   uint64 aa = absl::bit_cast<uint64>(a);
355   uint64 bb = absl::bit_cast<uint64>(b);
356   int64_t result = absl::bit_cast<int64>(aa + bb);
357   if (a >= 0 == b >= 0 && result >= 0 != a >= 0) {
358     return nullopt;
359   }
360   return result;
361 }
362 
363 // Computes a - b, returning nullopt if it overflows.
CheckedSubtract(int64_t a,int64_t b)364 optional<int64> CheckedSubtract(int64_t a, int64_t b) {
365   uint64 aa = absl::bit_cast<uint64>(a);
366   uint64 bb = absl::bit_cast<uint64>(b);
367   int64_t result = absl::bit_cast<int64>(aa - bb);
368   // Overflow occurred iff `a` and `b` have different signs and the sign of
369   // `a - b` is the same as that of `b`, see Hacker's Delight 2nd Ed. pp 29.
370   if (a >= 0 != b >= 0 && result >= 0 == b >= 0) {
371     return nullopt;
372   }
373   return result;
374 }
375 
376 // Check if
377 //  - `i` is initialized to a scalar constant K (namely, `indvar_init`),
378 //  - the while condition does `i < N` or `i <= N`, and
379 //  - the while body does `i++`.
380 // If so, it's trivial to compute the loop bound.
PatternMatchLoopTripCount(HloInstruction * while_op,int64_t indvar_tuple_idx,const Literal & indvar_init)381 static optional<int64> PatternMatchLoopTripCount(HloInstruction* while_op,
382                                                  int64_t indvar_tuple_idx,
383                                                  const Literal& indvar_init) {
384   // First, find the scalar constant K that `i` is initialized to.
385   optional<int64> indvar_init_val = LiteralAsScalarInt64(indvar_init);
386   if (!indvar_init_val) {
387     VLOG(2) << "Pattern-match failed: induction variable init is not a "
388                "constant scalar representable as an int64: "
389             << indvar_init.ToString();
390     return nullopt;
391   }
392 
393   // Check that `i` goes as `i++` in the while body.
394   //
395   // TODO(jlebar): We could also handle i-- and other idioms.
396   auto* while_body = while_op->while_body();
397   auto* while_body_indvar_update =
398       while_body->root_instruction()->operand(indvar_tuple_idx);
399   auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
400   if (!Match(while_body_indvar_update,
401              m::AddAnyOrder(m::Op().Is(while_body_indvar),
402                             m::ConstantEffectiveScalar(1)))) {
403     VLOG(2) << "Pattern-match failed: induction variable does not go as i++: "
404             << while_body_indvar_update->ToString();
405     return nullopt;
406   }
407 
408   // Check that we do op(i, N) or op(N, i) as the while condition.  Capture the
409   // value N.
410   auto* while_cond = while_op->while_condition();
411   auto* while_cond_root = while_cond->root_instruction();
412   auto* while_cond_indvar = NonConstantOperand(while_cond_root);
413   HloInstruction* while_cond_bound = nullptr;
414   if (!Match(while_cond_root,
415              m::Op().WithBinaryOperandsAnyOrder(
416                  m::Op().Is(while_cond_indvar),
417                  m::ConstantEffectiveScalar(&while_cond_bound)))) {
418     VLOG(2) << "Pattern-match failed: while condition is not of the form "
419                "op(i, N) or op(N, i).";
420     return nullopt;
421   }
422   // Note: If this succeeds, the constant `N` is representable as an int64 --
423   // that is, if it's an XLA U64, it fits within an int64.
424   optional<int64> while_cond_bound_val =
425       LiteralAsScalarInt64(while_cond_bound->literal());
426   if (!while_cond_bound_val) {
427     VLOG(2) << "Pattern-match failed: while condition induction variable is "
428                "not a constant scalar representable as an int64.";
429     return nullopt;
430   }
431 
432   // Handle `i = K; i < N; ++i`.
433   if (Match(while_cond_root,
434             m::Op()
435                 .WithComparisonDirection(ComparisonDirection::kLt)
436                 .WithOperand(0, m::Op().Is(while_cond_indvar)))) {
437     VLOG(2) << "Pattern-match succeeded: loop condition is i < N: "
438             << while_cond_root->ToString();
439     optional<int64> trips =
440         CheckedSubtract(*while_cond_bound_val, *indvar_init_val);
441     if (trips) {
442       return std::max(int64{0}, *trips);
443     } else {
444       VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX.";
445       return nullopt;
446     }
447   }
448 
449   // Handle `i = K; i <= N; ++i`.
450   if (Match(while_cond_root,
451             m::Op()
452                 .WithComparisonDirection(ComparisonDirection::kLe)
453                 .WithOperand(0, m::Op().Is(while_cond_indvar)))) {
454     VLOG(2) << "Pattern-match succeeded: loop condition is i <= N: "
455             << while_cond_root->ToString();
456     optional<int64> trips =
457         CheckedSubtract(*while_cond_bound_val, *indvar_init_val);
458     if (!trips) {
459       VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX";
460       return nullopt;
461     }
462     trips = CheckedAdd(*trips, 1);
463     if (!trips) {
464       VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX";
465       return nullopt;
466     }
467     return std::max<int64>(0, *trips);
468   }
469 
470   VLOG(2) << "Pattern-match failed: while condition follows unknown pattern: "
471           << while_cond_root->ToString();
472   return nullopt;
473 }
474 
ComputeWhileLoopTripCount(HloInstruction * while_op,int64_t max_brute_force_iters)475 optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op,
476                                           int64_t max_brute_force_iters) {
477   VLOG(2) << "Getting trip count for loop " << while_op->ToString();
478 
479   // The loop's induction variable is found at
480   //
481   //   get-tuple-elem(comp->parameter_instruction(0), *indvar_tuple_idx),
482   //
483   // where comp is while_op->while_body() or while_op->while_condition().
484   optional<int64> indvar_tuple_idx = GetLoopInductionVarTupleIdx(while_op);
485   if (!indvar_tuple_idx) {
486     return nullopt;
487   }
488 
489   // Now that we know the index of the induction variable, we can we can try to
490   // compute how many times the loop executes.  Start by computing the induction
491   // variable's initial value.
492   HloEvaluator evaluator(/*max_loop_iterations=*/0);
493   auto* while_init = while_op->mutable_operand(0);
494   auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx);
495   StatusOr<Literal> indvar_init_result = evaluator.Evaluate(indvar_init);
496   if (!indvar_init_result.ok()) {
497     VLOG(2) << "Couldn't evaluate induction variable init, "
498             << indvar_init_result.status() << ", " << indvar_init->ToString();
499     return nullopt;
500   }
501   Literal indvar_iter_val = std::move(indvar_init_result).ValueOrDie();
502 
503   // First, try to pattern-match.
504   if (auto trip_count = PatternMatchLoopTripCount(while_op, *indvar_tuple_idx,
505                                                   indvar_iter_val)) {
506     return trip_count;
507   }
508 
509   // If our pattern-match failed, try brute-forcing the loop trip count.
510   auto* while_body = while_op->while_body();
511   auto* while_body_indvar_update =
512       while_body->root_instruction()->operand(*indvar_tuple_idx);
513   auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
514 
515   auto* while_cond = while_op->while_condition();
516   auto* while_cond_root = while_cond->root_instruction();
517   auto* while_cond_indvar = NonConstantOperand(while_cond_root);
518 
519   for (int64_t trip_count = 0; trip_count != max_brute_force_iters + 1;
520        ++trip_count) {
521     StatusOr<Literal> result = evaluator.EvaluateWithSubstitutions(
522         while_cond_root, {{while_cond_indvar, &indvar_iter_val}});
523     if (!result.ok()) {
524       VLOG(2) << "Couldn't evaluate while cond: " << result.status();
525       return nullopt;
526     }
527     if (result.ValueOrDie().data<bool>() == absl::Span<const bool>{false}) {
528       VLOG(2) << "Loop has static trip count of " << trip_count;
529       return trip_count;
530     }
531 
532     // Calculate the value of the induction variable after one iteration of the
533     // loop, and check whether the while condition is true with this new value.
534     StatusOr<Literal> indvar_next_result = evaluator.EvaluateWithSubstitutions(
535         while_body_indvar_update, {{while_body_indvar, &indvar_iter_val}});
536     if (!indvar_next_result.ok()) {
537       VLOG(2) << "Couldn't evaluate induction variable update: "
538               << indvar_next_result.status();
539       return nullopt;
540     }
541     indvar_iter_val = std::move(indvar_next_result).ValueOrDie();
542   }
543 
544   VLOG(2) << "Loop has unknown trip count.";
545   return nullopt;
546 }
547 
548 // If the only user of this instruction is a get-tuple-element, return that
549 // get-tuple-element, otherwise return null. If this runs before CSE/DCE, we may
550 // get a false negative if there are several copies of the same GTE, or there
551 // are unused GTEs, but we can live with this.
GetOnlyGTE(HloInstruction * inst)552 static HloInstruction* GetOnlyGTE(HloInstruction* inst) {
553   if (inst->user_count() != 1) {
554     return nullptr;
555   }
556 
557   HloInstruction* user = inst->users().back();
558   if (user->opcode() != HloOpcode::kGetTupleElement) {
559     return nullptr;
560   }
561   return user;
562 }
563 
ComputeWhileLoopTripCountUpperBound(HloInstruction * while_op)564 optional<int64> ComputeWhileLoopTripCountUpperBound(HloInstruction* while_op) {
565   // If we know the exact trip count, it's also the upper bound.
566   auto exact_trip_count = ComputeWhileLoopTripCount(while_op);
567   if (exact_trip_count) {
568     VLOG(2) << "Loop has exact trip count.";
569     return exact_trip_count;
570   }
571 
572   // There is one more case we know how to handle. If the loop condition only
573   // looks at one element of the tuple, and the loop body sets this element to a
574   // constant, there are two options:
575   // 1) Evaluating the condition on this constant returns true. In this case,
576   // the loop either executes 0 times, or is an infinite loop, depending on the
577   // init value.
578   // 2) Evaluating the condition on this constant returns false. In this case,
579   // the loop executes 0 or 1 times, depending on the init value. This means
580   // that, regardless of the init value, the upper bound on the trip count is 1.
581 
582   // Check whether the condition depends on a single parameter, and find out
583   // which.
584   auto* while_cond = while_op->while_condition();
585   auto* while_cond_param = while_cond->parameter_instruction(0);
586   auto* cond_gte = GetOnlyGTE(while_cond_param);
587   if (!cond_gte) {
588     VLOG(2) << "Induction variable not found in loop condition: "
589             << while_cond->root_instruction()->ToString();
590     return nullopt;
591   }
592 
593   // Now check whether this gets set to a constant by the while body.
594   auto* while_body = while_op->while_body();
595   auto* while_body_root = while_body->root_instruction();
596   if (while_body_root->opcode() != HloOpcode::kTuple) {
597     VLOG(3) << "While body's root is not a tuple instruction: "
598             << while_body_root->ToString();
599     return nullopt;
600   }
601 
602   int64_t indvar_index = cond_gte->tuple_index();
603   auto* while_body_indvar = while_body_root->operand(indvar_index);
604   if (while_body_indvar->opcode() != HloOpcode::kConstant) {
605     VLOG(3) << "While body does not set the IV to a constant: "
606             << while_body_indvar->ToString();
607     return nullopt;
608   }
609 
610   // We have a constant. Evaluate the condition on this constant.
611   HloEvaluator evaluator(/*max_loop_iterations=*/0);
612   Literal fake_input = Literal::CreateFromShape(while_cond_param->shape());
613   TF_CHECK_OK(fake_input.CopyFrom(while_body_indvar->literal(),
614                                   /*dest_shape_index=*/{indvar_index},
615                                   /*src_shape_index=*/{}));
616   StatusOr<Literal> eval_result =
617       evaluator.Evaluate(*while_cond, {std::move(fake_input)});
618 
619   if (!eval_result.ok()) {
620     VLOG(2) << "Couldn't evaluate while loop condition.";
621     return nullopt;
622   }
623 
624   Literal cond_result_pred = std::move(eval_result.ValueOrDie());
625   CHECK(Shape::Equal().IgnoreLayout()(cond_result_pred.shape(),
626                                       ShapeUtil::MakeShape(PRED, {})));
627 
628   // Per the explanation above, if the evaluated condition returns false, the
629   // loop executes at most once.
630   bool cond_returns_true = cond_result_pred.GetFirstElement<bool>();
631   if (!cond_returns_true) {
632     VLOG(2) << "Upper bound on the trip count is 1";
633     return 1;
634   }
635 
636   VLOG(2) << "Loop has no known upper bound on the trip count.";
637   return nullopt;
638 }
639 
640 }  // namespace xla
641