1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/while_loop_analysis.h"
17
18 #include "absl/base/casts.h"
19 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
20 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
21 #include "tensorflow/compiler/xla/service/hlo_module.h"
22 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
23 #include "tensorflow/compiler/xla/service/hlo_reachability.h"
24 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
25
26 namespace xla {
27
28 using absl::nullopt;
29 using absl::optional;
30 namespace m = match;
31
32 // Finds and returns the non-constant operand in instr.
33 //
34 // CHECK-fails if instr doesn't have exactly one unique non-constant operand.
NonConstantOperand(const HloInstruction * instr)35 static const HloInstruction* NonConstantOperand(const HloInstruction* instr) {
36 const HloInstruction* result = nullptr;
37 for (const HloInstruction* operand : instr->operands()) {
38 if (!operand->IsConstant()) {
39 if (result != nullptr) {
40 CHECK_EQ(result, operand);
41 }
42 result = operand;
43 }
44 }
45 CHECK_NE(result, nullptr);
46 return result;
47 }
48
49 // If all of instr's operands are either constants or have the form
50 // get-tuple-element(gte_operand, N)
51 // for the same value N, returns N. Otherwise, returns nullopt.
GetGTEOperandIndex(const HloInstruction * instr,const HloInstruction * gte_operand)52 static optional<int64> GetGTEOperandIndex(const HloInstruction* instr,
53 const HloInstruction* gte_operand) {
54 VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", "
55 << gte_operand->ToString() << ")";
56
57 // Among the operands of `instr`, find one that is a get-tuple-element op or
58 // the one that is copy fed by a get-tuple-element.
59 auto gte_it =
60 absl::c_find_if(instr->operands(), [](const HloInstruction* instr) {
61 return (instr->opcode() == HloOpcode::kGetTupleElement) ||
62 (instr->opcode() == HloOpcode::kCopy &&
63 instr->operand(0)->opcode() == HloOpcode::kGetTupleElement);
64 });
65 if (gte_it == instr->operands().end()) {
66 VLOG(2) << "instr does not have a gte operand.";
67 return nullopt;
68 }
69
70 // All operands of `instr` must be either constants or of the form
71 // get-tuple-element(gte_operand, tuple_idx)
72 // for the same value tuple_idx.
73 int64_t tuple_idx = (*gte_it)->tuple_index();
74 for (const HloInstruction* operand : instr->operands()) {
75 if (!Match(operand, m::Constant()) &&
76 !Match(operand,
77 m::GetTupleElement(m::Op().Is(gte_operand), tuple_idx))) {
78 VLOG(2)
79 << "instr uses something other than a constant or gte(gte_operand, "
80 << tuple_idx << "): " << operand->ToString();
81 return nullopt;
82 }
83 }
84 return tuple_idx;
85 }
86
87 // The below function identifies a subset of all possible auxiliary
88 // induction variables (AIV). Specifically, candidates are gtes, e.g.,
89 // gte(param0, N)
90 // The function checks if the loop body plumbs the AIV
91 // through the same tuple index at root, and that ops involving AIV
92 // involve constants.
93 // op2 = op(constants, gte(param0, N), constants)
94 // op3 = op(constants, f(op2, gte(param0, N), constants)
95 // op4 = op(constants, f(op3, constants)
96 // root = tuple(..., op4, ...)
97 // Further, the ops are restricted to basic math ops (+,-,*,/).
98 // Finally, loop invariant GTEs are excluded from AIVs.
99 // We can expand the ops category/nature of AIVs as needed.
GetAuxiliaryLoopInductionVars(const HloInstruction * while_op)100 std::vector<const HloInstruction*> GetAuxiliaryLoopInductionVars(
101 const HloInstruction* while_op) {
102 std::vector<const HloInstruction*> aux_ind_gte;
103 CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
104 auto* while_body = while_op->while_body();
105 auto* while_body_param = while_body->parameter_instruction(0);
106 VLOG(2) << "Aux Induction Variables for loop:" << while_op->ToShortString();
107 VLOG(2) << "the parameter instr:" << while_body_param->ToShortString();
108 VLOG(2) << "the parameter user count:" << while_body_param->users().size();
109 if (while_body_param == nullptr) return aux_ind_gte;
110
111 // candidates_pairs = pair<inst, inst>(
112 // operands of the root while body,
113 // GTE only operands that index into the same position in the parameter)
114 // for each candidate_pair (x, y)
115 // find all paths between x and y,
116 // each paths should satisfy the above listed criterion
117 // index that x and y used is added as a aux variable index
118 std::map<int64, const HloInstruction*> extractions;
119 for (const HloInstruction* indx_instr : while_body_param->users()) {
120 if (indx_instr->opcode() != HloOpcode::kGetTupleElement) {
121 continue;
122 }
123 auto it = extractions.find(indx_instr->tuple_index());
124 // if we find two extractions at the same index, we ignore such
125 // a candidate
126 if (it != extractions.end()) {
127 it->second = nullptr;
128 VLOG(2) << "two extractions at same index:" << indx_instr->ToString();
129 } else {
130 extractions.insert(std::make_pair(indx_instr->tuple_index(), indx_instr));
131 VLOG(2) << "inserting extraction :" << indx_instr->ToString();
132 }
133 }
134 VLOG(2) << "total extractions size:" << extractions.size() << std::endl;
135 if (extractions.empty()) {
136 return aux_ind_gte;
137 }
138
139 auto* while_body_root = while_body->root_instruction();
140 if (while_body_root->opcode() != HloOpcode::kTuple) {
141 VLOG(2) << "While body root is not a tuple:" << while_body_root->ToString();
142 return aux_ind_gte;
143 }
144 int64_t index = -1;
145 std::map<int64, const HloInstruction*> insertions;
146 for (const HloInstruction* operand : while_body_root->operands()) {
147 index++;
148 if (!operand->IsConstant()) {
149 auto it = insertions.find(index);
150 if (it != insertions.end()) {
151 it->second = nullptr;
152 VLOG(2) << "two insertions at same index:" << operand->ToString();
153 } else {
154 insertions.insert(std::make_pair(index, operand));
155 VLOG(2) << "inserting insertions:" << operand->ToString();
156 }
157 }
158 }
159 if (insertions.empty()) {
160 return aux_ind_gte;
161 }
162
163 std::map<int64, std::pair<const HloInstruction*, const HloInstruction*>>
164 candidate_pairs;
165 for (; index >= 0; --index) {
166 const HloInstruction *ext, *inst;
167 ext = (extractions.find(index) != extractions.end())
168 ? extractions.find(index)->second
169 : nullptr;
170 inst = (insertions.find(index) != insertions.end())
171 ? insertions.find(index)->second
172 : nullptr;
173 if (ext != nullptr && inst != nullptr) {
174 // Filter out trivial aux, i.e., extract directly to an insert.
175 if (ext != inst) {
176 candidate_pairs.insert(
177 std::make_pair(index, std::make_pair(ext, inst)));
178 }
179 }
180 }
181 VLOG(2) << "total candidate pairs:" << candidate_pairs.size() << std::endl;
182
183 // Passed to ReachabilityMap to decide the type of produce-consumer edges
184 // along the reachability path.
185 const auto add_dependencies = [](const HloInstruction* hlo,
186 std::vector<HloInstruction*>* inputs) {
187 HloInstruction* non_const_operand = nullptr;
188 int num_non_constants = 0;
189 for (HloInstruction* operand : hlo->operands()) {
190 if (!operand->IsConstant()) {
191 num_non_constants++;
192 non_const_operand = operand;
193 }
194 }
195 if (num_non_constants == 1 &&
196 (hlo->opcode() == HloOpcode::kGetTupleElement ||
197 hlo->opcode() == HloOpcode::kAdd ||
198 hlo->opcode() == HloOpcode::kMultiply ||
199 hlo->opcode() == HloOpcode::kDivide ||
200 hlo->opcode() == HloOpcode::kSubtract)) {
201 inputs->push_back(non_const_operand);
202 }
203 };
204
205 std::unique_ptr<HloReachabilityMap> hrm =
206 HloReachabilityMap::BuildWithRestrictions(
207 while_body,
208 absl::FunctionRef<void(const HloInstruction* hlo,
209 std::vector<HloInstruction*>* inputs)>(
210 add_dependencies));
211
212 for (auto candidates : candidate_pairs) {
213 VLOG(2) << "are reachable?:" << (candidates.second.first)->ToString()
214 << "*************" << (candidates.second.second)->ToString()
215 << std::endl;
216 if (hrm->IsReachable(candidates.second.first, candidates.second.second)) {
217 aux_ind_gte.push_back(candidates.second.first);
218 VLOG(2) << "YES";
219 } else {
220 VLOG(2) << "NO";
221 }
222 }
223 VLOG(2) << "num auxiliary candidates :" << aux_ind_gte.size();
224 return aux_ind_gte;
225 }
226
227 // Tries to get the tuple index of the induction variable of a while loop.
228 //
229 // Checks that the loop condition and body both plumb the induction variable
230 // through the same tuple index, and that they both apply exactly one op to the
231 // induction variable before deciding whether to do another loop iteration (in
232 // the loop condition's case) or packing the induction variable into the result
233 // tuple (in the loop body's case).
234 //
235 // Specifically, checks that the loop condition has structure
236 //
237 // root = op(constants, get-tuple-elem(param0, N), constants)
238 //
239 // and the loop body has the structure
240 //
241 // inc = op(constants, get-tuple-elem(param0, N), constants)
242 // root = tuple(..., inc, ...) // inc is N'th operand of tuple().
243 //
244 // If so, returns N. Otherwise, returns nullopt.
GetLoopInductionVarTupleIdx(const HloInstruction * while_op)245 optional<int64> GetLoopInductionVarTupleIdx(const HloInstruction* while_op) {
246 CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
247 VLOG(2) << "Finding induction variable for loop "
248 << while_op->ToShortString();
249
250 // The while_cond computation should have the form
251 //
252 // while_cond_root =
253 // op(constants, get-tuple-elem(while_cond_param, N), constants).
254 //
255 // If it does, set indvar_tuple_idx to N.
256 auto* while_cond = while_op->while_condition();
257 auto* while_cond_root = while_cond->root_instruction();
258 auto* while_cond_param = while_cond->parameter_instruction(0);
259 optional<int64> indvar_tuple_idx =
260 GetGTEOperandIndex(while_cond_root, while_cond_param);
261 if (!indvar_tuple_idx) {
262 VLOG(2) << "Induction variable not found in loop condition: "
263 << while_cond->root_instruction()->ToString();
264 return nullopt;
265 }
266
267 // The while_body computation should have the form
268 //
269 // while_body_inc =
270 // op(constants, get-tuple-elem(while_body_param, N), constants)
271 // while_body_root = tuple(..., while_body_inc, ...)
272 //
273 // where while_body_inc is operand N of while_body_root.
274 auto* while_body = while_op->while_body();
275 auto* while_body_root = while_body->root_instruction();
276 if (while_body_root->opcode() != HloOpcode::kTuple) {
277 VLOG(2) << "While body's root is not a tuple instruction: "
278 << while_body_root->ToString();
279 return nullopt;
280 }
281
282 auto* while_body_inc = while_body_root->operand(*indvar_tuple_idx);
283 auto* while_body_param = while_body->parameter_instruction(0);
284 optional<int64> while_body_indvar_tuple_idx =
285 GetGTEOperandIndex(while_body_inc, while_body_param);
286 if (!while_body_indvar_tuple_idx) {
287 VLOG(2)
288 << "Induction variable not found in while body increment instruction: "
289 << while_body_inc->ToString();
290 return nullopt;
291 }
292 if (while_body_indvar_tuple_idx != indvar_tuple_idx) {
293 VLOG(2) << "Tuple index of induction variable does not match between loop "
294 "condition ("
295 << *indvar_tuple_idx << ") and while body ("
296 << *while_body_indvar_tuple_idx << ")";
297 return nullopt;
298 }
299
300 // Finally, check that the while loop's initial value is a tuple with enough
301 // elements.
302 auto* while_init = while_op->operand(0);
303 if (while_init->opcode() != HloOpcode::kTuple) {
304 VLOG(2) << "While init expected to be a tuple: " << while_init->ToString();
305 return nullopt;
306 }
307
308 VLOG(2) << "Induction variable's tuple index: " << *indvar_tuple_idx;
309 return indvar_tuple_idx;
310 }
311
312 // Converts the given literal to a scalar int64, if possible.
313 //
314 // Fails if the literal is not an integral type or if the value it contains
315 // cannot be represented in an int64.
LiteralAsScalarInt64(const Literal & l)316 static optional<int64> LiteralAsScalarInt64(const Literal& l) {
317 if (!ShapeUtil::IsEffectiveScalar(l.shape())) {
318 VLOG(2) << "literal is not an effective scalar: " << l.ToString();
319 return nullopt;
320 }
321 switch (l.shape().element_type()) {
322 case S8:
323 return l.GetFirstElement<int8>();
324 case S16:
325 return l.GetFirstElement<int16>();
326 case S32:
327 return l.GetFirstElement<int32>();
328 case S64:
329 return l.GetFirstElement<int64>();
330 case U8:
331 return l.GetFirstElement<uint8>();
332 case U16:
333 return l.GetFirstElement<uint16>();
334 case U32:
335 return l.GetFirstElement<uint32>();
336 case U64: {
337 uint64 v = l.GetFirstElement<uint64>();
338 if (v > static_cast<uint64>(std::numeric_limits<int64>::max())) {
339 VLOG(2) << "uint64 literal is out of range for int64: " << v;
340 return nullopt;
341 }
342 return v;
343 }
344 default:
345 VLOG(2) << "literal is of non-integral type " << l.shape().ToString();
346 return nullopt;
347 }
348 }
349
350 // Computes a + b, returning nullopt if it overflows.
CheckedAdd(int64_t a,int64_t b)351 optional<int64> CheckedAdd(int64_t a, int64_t b) {
352 // Overflow occurred iff `a` and `b` have the same sign and `a + b` has a
353 // different sign, see Hacker's Delignt 2nd Ed. pp 28.
354 uint64 aa = absl::bit_cast<uint64>(a);
355 uint64 bb = absl::bit_cast<uint64>(b);
356 int64_t result = absl::bit_cast<int64>(aa + bb);
357 if (a >= 0 == b >= 0 && result >= 0 != a >= 0) {
358 return nullopt;
359 }
360 return result;
361 }
362
363 // Computes a - b, returning nullopt if it overflows.
CheckedSubtract(int64_t a,int64_t b)364 optional<int64> CheckedSubtract(int64_t a, int64_t b) {
365 uint64 aa = absl::bit_cast<uint64>(a);
366 uint64 bb = absl::bit_cast<uint64>(b);
367 int64_t result = absl::bit_cast<int64>(aa - bb);
368 // Overflow occurred iff `a` and `b` have different signs and the sign of
369 // `a - b` is the same as that of `b`, see Hacker's Delight 2nd Ed. pp 29.
370 if (a >= 0 != b >= 0 && result >= 0 == b >= 0) {
371 return nullopt;
372 }
373 return result;
374 }
375
376 // Check if
377 // - `i` is initialized to a scalar constant K (namely, `indvar_init`),
378 // - the while condition does `i < N` or `i <= N`, and
379 // - the while body does `i++`.
380 // If so, it's trivial to compute the loop bound.
PatternMatchLoopTripCount(HloInstruction * while_op,int64_t indvar_tuple_idx,const Literal & indvar_init)381 static optional<int64> PatternMatchLoopTripCount(HloInstruction* while_op,
382 int64_t indvar_tuple_idx,
383 const Literal& indvar_init) {
384 // First, find the scalar constant K that `i` is initialized to.
385 optional<int64> indvar_init_val = LiteralAsScalarInt64(indvar_init);
386 if (!indvar_init_val) {
387 VLOG(2) << "Pattern-match failed: induction variable init is not a "
388 "constant scalar representable as an int64: "
389 << indvar_init.ToString();
390 return nullopt;
391 }
392
393 // Check that `i` goes as `i++` in the while body.
394 //
395 // TODO(jlebar): We could also handle i-- and other idioms.
396 auto* while_body = while_op->while_body();
397 auto* while_body_indvar_update =
398 while_body->root_instruction()->operand(indvar_tuple_idx);
399 auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
400 if (!Match(while_body_indvar_update,
401 m::AddAnyOrder(m::Op().Is(while_body_indvar),
402 m::ConstantEffectiveScalar(1)))) {
403 VLOG(2) << "Pattern-match failed: induction variable does not go as i++: "
404 << while_body_indvar_update->ToString();
405 return nullopt;
406 }
407
408 // Check that we do op(i, N) or op(N, i) as the while condition. Capture the
409 // value N.
410 auto* while_cond = while_op->while_condition();
411 auto* while_cond_root = while_cond->root_instruction();
412 auto* while_cond_indvar = NonConstantOperand(while_cond_root);
413 HloInstruction* while_cond_bound = nullptr;
414 if (!Match(while_cond_root,
415 m::Op().WithBinaryOperandsAnyOrder(
416 m::Op().Is(while_cond_indvar),
417 m::ConstantEffectiveScalar(&while_cond_bound)))) {
418 VLOG(2) << "Pattern-match failed: while condition is not of the form "
419 "op(i, N) or op(N, i).";
420 return nullopt;
421 }
422 // Note: If this succeeds, the constant `N` is representable as an int64 --
423 // that is, if it's an XLA U64, it fits within an int64.
424 optional<int64> while_cond_bound_val =
425 LiteralAsScalarInt64(while_cond_bound->literal());
426 if (!while_cond_bound_val) {
427 VLOG(2) << "Pattern-match failed: while condition induction variable is "
428 "not a constant scalar representable as an int64.";
429 return nullopt;
430 }
431
432 // Handle `i = K; i < N; ++i`.
433 if (Match(while_cond_root,
434 m::Op()
435 .WithComparisonDirection(ComparisonDirection::kLt)
436 .WithOperand(0, m::Op().Is(while_cond_indvar)))) {
437 VLOG(2) << "Pattern-match succeeded: loop condition is i < N: "
438 << while_cond_root->ToString();
439 optional<int64> trips =
440 CheckedSubtract(*while_cond_bound_val, *indvar_init_val);
441 if (trips) {
442 return std::max(int64{0}, *trips);
443 } else {
444 VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX.";
445 return nullopt;
446 }
447 }
448
449 // Handle `i = K; i <= N; ++i`.
450 if (Match(while_cond_root,
451 m::Op()
452 .WithComparisonDirection(ComparisonDirection::kLe)
453 .WithOperand(0, m::Op().Is(while_cond_indvar)))) {
454 VLOG(2) << "Pattern-match succeeded: loop condition is i <= N: "
455 << while_cond_root->ToString();
456 optional<int64> trips =
457 CheckedSubtract(*while_cond_bound_val, *indvar_init_val);
458 if (!trips) {
459 VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX";
460 return nullopt;
461 }
462 trips = CheckedAdd(*trips, 1);
463 if (!trips) {
464 VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX";
465 return nullopt;
466 }
467 return std::max<int64>(0, *trips);
468 }
469
470 VLOG(2) << "Pattern-match failed: while condition follows unknown pattern: "
471 << while_cond_root->ToString();
472 return nullopt;
473 }
474
ComputeWhileLoopTripCount(HloInstruction * while_op,int64_t max_brute_force_iters)475 optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op,
476 int64_t max_brute_force_iters) {
477 VLOG(2) << "Getting trip count for loop " << while_op->ToString();
478
479 // The loop's induction variable is found at
480 //
481 // get-tuple-elem(comp->parameter_instruction(0), *indvar_tuple_idx),
482 //
483 // where comp is while_op->while_body() or while_op->while_condition().
484 optional<int64> indvar_tuple_idx = GetLoopInductionVarTupleIdx(while_op);
485 if (!indvar_tuple_idx) {
486 return nullopt;
487 }
488
489 // Now that we know the index of the induction variable, we can we can try to
490 // compute how many times the loop executes. Start by computing the induction
491 // variable's initial value.
492 HloEvaluator evaluator(/*max_loop_iterations=*/0);
493 auto* while_init = while_op->mutable_operand(0);
494 auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx);
495 StatusOr<Literal> indvar_init_result = evaluator.Evaluate(indvar_init);
496 if (!indvar_init_result.ok()) {
497 VLOG(2) << "Couldn't evaluate induction variable init, "
498 << indvar_init_result.status() << ", " << indvar_init->ToString();
499 return nullopt;
500 }
501 Literal indvar_iter_val = std::move(indvar_init_result).ValueOrDie();
502
503 // First, try to pattern-match.
504 if (auto trip_count = PatternMatchLoopTripCount(while_op, *indvar_tuple_idx,
505 indvar_iter_val)) {
506 return trip_count;
507 }
508
509 // If our pattern-match failed, try brute-forcing the loop trip count.
510 auto* while_body = while_op->while_body();
511 auto* while_body_indvar_update =
512 while_body->root_instruction()->operand(*indvar_tuple_idx);
513 auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
514
515 auto* while_cond = while_op->while_condition();
516 auto* while_cond_root = while_cond->root_instruction();
517 auto* while_cond_indvar = NonConstantOperand(while_cond_root);
518
519 for (int64_t trip_count = 0; trip_count != max_brute_force_iters + 1;
520 ++trip_count) {
521 StatusOr<Literal> result = evaluator.EvaluateWithSubstitutions(
522 while_cond_root, {{while_cond_indvar, &indvar_iter_val}});
523 if (!result.ok()) {
524 VLOG(2) << "Couldn't evaluate while cond: " << result.status();
525 return nullopt;
526 }
527 if (result.ValueOrDie().data<bool>() == absl::Span<const bool>{false}) {
528 VLOG(2) << "Loop has static trip count of " << trip_count;
529 return trip_count;
530 }
531
532 // Calculate the value of the induction variable after one iteration of the
533 // loop, and check whether the while condition is true with this new value.
534 StatusOr<Literal> indvar_next_result = evaluator.EvaluateWithSubstitutions(
535 while_body_indvar_update, {{while_body_indvar, &indvar_iter_val}});
536 if (!indvar_next_result.ok()) {
537 VLOG(2) << "Couldn't evaluate induction variable update: "
538 << indvar_next_result.status();
539 return nullopt;
540 }
541 indvar_iter_val = std::move(indvar_next_result).ValueOrDie();
542 }
543
544 VLOG(2) << "Loop has unknown trip count.";
545 return nullopt;
546 }
547
548 // If the only user of this instruction is a get-tuple-element, return that
549 // get-tuple-element, otherwise return null. If this runs before CSE/DCE, we may
550 // get a false negative if there are several copies of the same GTE, or there
551 // are unused GTEs, but we can live with this.
GetOnlyGTE(HloInstruction * inst)552 static HloInstruction* GetOnlyGTE(HloInstruction* inst) {
553 if (inst->user_count() != 1) {
554 return nullptr;
555 }
556
557 HloInstruction* user = inst->users().back();
558 if (user->opcode() != HloOpcode::kGetTupleElement) {
559 return nullptr;
560 }
561 return user;
562 }
563
ComputeWhileLoopTripCountUpperBound(HloInstruction * while_op)564 optional<int64> ComputeWhileLoopTripCountUpperBound(HloInstruction* while_op) {
565 // If we know the exact trip count, it's also the upper bound.
566 auto exact_trip_count = ComputeWhileLoopTripCount(while_op);
567 if (exact_trip_count) {
568 VLOG(2) << "Loop has exact trip count.";
569 return exact_trip_count;
570 }
571
572 // There is one more case we know how to handle. If the loop condition only
573 // looks at one element of the tuple, and the loop body sets this element to a
574 // constant, there are two options:
575 // 1) Evaluating the condition on this constant returns true. In this case,
576 // the loop either executes 0 times, or is an infinite loop, depending on the
577 // init value.
578 // 2) Evaluating the condition on this constant returns false. In this case,
579 // the loop executes 0 or 1 times, depending on the init value. This means
580 // that, regardless of the init value, the upper bound on the trip count is 1.
581
582 // Check whether the condition depends on a single parameter, and find out
583 // which.
584 auto* while_cond = while_op->while_condition();
585 auto* while_cond_param = while_cond->parameter_instruction(0);
586 auto* cond_gte = GetOnlyGTE(while_cond_param);
587 if (!cond_gte) {
588 VLOG(2) << "Induction variable not found in loop condition: "
589 << while_cond->root_instruction()->ToString();
590 return nullopt;
591 }
592
593 // Now check whether this gets set to a constant by the while body.
594 auto* while_body = while_op->while_body();
595 auto* while_body_root = while_body->root_instruction();
596 if (while_body_root->opcode() != HloOpcode::kTuple) {
597 VLOG(3) << "While body's root is not a tuple instruction: "
598 << while_body_root->ToString();
599 return nullopt;
600 }
601
602 int64_t indvar_index = cond_gte->tuple_index();
603 auto* while_body_indvar = while_body_root->operand(indvar_index);
604 if (while_body_indvar->opcode() != HloOpcode::kConstant) {
605 VLOG(3) << "While body does not set the IV to a constant: "
606 << while_body_indvar->ToString();
607 return nullopt;
608 }
609
610 // We have a constant. Evaluate the condition on this constant.
611 HloEvaluator evaluator(/*max_loop_iterations=*/0);
612 Literal fake_input = Literal::CreateFromShape(while_cond_param->shape());
613 TF_CHECK_OK(fake_input.CopyFrom(while_body_indvar->literal(),
614 /*dest_shape_index=*/{indvar_index},
615 /*src_shape_index=*/{}));
616 StatusOr<Literal> eval_result =
617 evaluator.Evaluate(*while_cond, {std::move(fake_input)});
618
619 if (!eval_result.ok()) {
620 VLOG(2) << "Couldn't evaluate while loop condition.";
621 return nullopt;
622 }
623
624 Literal cond_result_pred = std::move(eval_result.ValueOrDie());
625 CHECK(Shape::Equal().IgnoreLayout()(cond_result_pred.shape(),
626 ShapeUtil::MakeShape(PRED, {})));
627
628 // Per the explanation above, if the evaluated condition returns false, the
629 // loop executes at most once.
630 bool cond_returns_true = cond_result_pred.GetFirstElement<bool>();
631 if (!cond_returns_true) {
632 VLOG(2) << "Upper bound on the trip count is 1";
633 return 1;
634 }
635
636 VLOG(2) << "Loop has no known upper bound on the trip count.";
637 return nullopt;
638 }
639
640 } // namespace xla
641