1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/while_loop_analysis.h"
17
18 #include "absl/base/casts.h"
19 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
20 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
21 #include "tensorflow/compiler/xla/service/hlo_module.h"
22 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
23 #include "tensorflow/compiler/xla/service/hlo_reachability.h"
24 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
25
26 namespace xla {
27
28 using absl::nullopt;
29 using absl::optional;
30 namespace m = match;
31
32 // Finds and returns the non-constant operand in instr.
33 //
34 // CHECK-fails if instr doesn't have exactly one unique non-constant operand.
NonConstantOperand(const HloInstruction * instr)35 static const HloInstruction* NonConstantOperand(const HloInstruction* instr) {
36 const HloInstruction* result = nullptr;
37 for (const HloInstruction* operand : instr->operands()) {
38 if (!operand->IsConstant()) {
39 if (result != nullptr) {
40 CHECK_EQ(result, operand);
41 }
42 result = operand;
43 }
44 }
45 CHECK_NE(result, nullptr);
46 return result;
47 }
48
49 // If all of instr's operands are either constants or have the form
50 // get-tuple-element(gte_operand, N)
51 // for the same value N, returns N. Otherwise, returns nullopt.
GetGTEOperandIndex(const HloInstruction * instr,const HloInstruction * gte_operand)52 static optional<int64> GetGTEOperandIndex(const HloInstruction* instr,
53 const HloInstruction* gte_operand) {
54 VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", "
55 << gte_operand->ToString() << ")";
56
57 // Among the operands of `instr`, find one that is a get-tuple-element op.
58 auto gte_it = c_find_if(instr->operands(), [](const HloInstruction* instr) {
59 return instr->opcode() == HloOpcode::kGetTupleElement;
60 });
61 if (gte_it == instr->operands().end()) {
62 VLOG(2) << "instr does not have a gte operand.";
63 return nullopt;
64 }
65
66 // All operands of `instr` must be either constants or of the form
67 // get-tuple-element(gte_operand, tuple_idx)
68 // for the same value tuple_idx.
69 int64 tuple_idx = (*gte_it)->tuple_index();
70 for (const HloInstruction* operand : instr->operands()) {
71 if (!Match(operand, m::Constant()) &&
72 !Match(operand,
73 m::GetTupleElement(m::Op().Is(gte_operand), tuple_idx))) {
74 VLOG(2)
75 << "instr uses something other than a constant or gte(gte_operand, "
76 << tuple_idx << "): " << operand->ToString();
77 return nullopt;
78 }
79 }
80 return tuple_idx;
81 }
82
83 // The below function identifies a subset of all possible auxiliary
84 // induction variables (AIV). Specifically, candidates are gtes, e.g.,
85 // gte(param0, N)
86 // The function checks if the loop body plumbs the AIV
87 // through the same tuple index at root, and that ops involving AIV
88 // involve constants.
89 // op2 = op(constants, gte(param0, N), constants)
90 // op3 = op(constants, f(op2, gte(param0, N), constants)
91 // op4 = op(constants, f(op3, constants)
92 // root = tuple(..., op4, ...)
93 // Further, the ops are restricted to basic math ops (+,-,*,/).
94 // Finally, loop invariant GTEs are excluded from AIVs.
95 // We can expand the ops category/nature of AIVs as needed.
GetAuxiliaryLoopInductionVars(const HloInstruction * while_op)96 std::vector<const HloInstruction*> GetAuxiliaryLoopInductionVars(
97 const HloInstruction* while_op) {
98 std::vector<const HloInstruction*> aux_ind_gte;
99 CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
100 auto* while_body = while_op->while_body();
101 auto* while_body_param = while_body->parameter_instruction(0);
102 VLOG(2) << "Aux Induction Variables for loop:" << while_op->ToShortString();
103 VLOG(2) << "the parameter instr:" << while_body_param->ToShortString();
104 VLOG(2) << "the parameter user count:" << while_body_param->users().size();
105 if (while_body_param == nullptr) return aux_ind_gte;
106
107 // candidates_pairs = pair<inst, inst>(
108 // operands of the root while body,
109 // GTE only operands that index into the same position in the parameter)
110 // for each candidate_pair (x, y)
111 // find all paths between x and y,
112 // each paths should satisfy the above listed criterion
113 // index that x and y used is added as a aux variable index
114 std::map<int64, const HloInstruction*> extractions;
115 for (const HloInstruction* indx_instr : while_body_param->users()) {
116 if (indx_instr->opcode() != HloOpcode::kGetTupleElement) {
117 continue;
118 }
119 auto it = extractions.find(indx_instr->tuple_index());
120 // if we find two extractions at the same index, we ignore such
121 // a candidate
122 if (it != extractions.end()) {
123 it->second = nullptr;
124 VLOG(2) << "two extractions at same index:" << indx_instr->ToString();
125 } else {
126 extractions.insert(std::make_pair(indx_instr->tuple_index(), indx_instr));
127 VLOG(2) << "inserting extraction :" << indx_instr->ToString();
128 }
129 }
130 VLOG(2) << "total extractions size:" << extractions.size() << std::endl;
131 if (extractions.empty()) {
132 return aux_ind_gte;
133 }
134
135 auto* while_body_root = while_body->root_instruction();
136 if (while_body_root->opcode() != HloOpcode::kTuple) {
137 VLOG(2) << "While body root is not a tuple:" << while_body_root->ToString();
138 return aux_ind_gte;
139 }
140 int64 index = -1;
141 std::map<int64, const HloInstruction*> insertions;
142 for (const HloInstruction* operand : while_body_root->operands()) {
143 index++;
144 if (!operand->IsConstant()) {
145 auto it = insertions.find(index);
146 if (it != insertions.end()) {
147 it->second = nullptr;
148 VLOG(2) << "two insertions at same index:" << operand->ToString();
149 } else {
150 insertions.insert(std::make_pair(index, operand));
151 VLOG(2) << "inserting insertions:" << operand->ToString();
152 }
153 }
154 }
155 if (insertions.empty()) {
156 return aux_ind_gte;
157 }
158
159 std::map<int64, std::pair<const HloInstruction*, const HloInstruction*>>
160 candidate_pairs;
161 for (; index >= 0; --index) {
162 const HloInstruction *ext, *inst;
163 ext = (extractions.find(index) != extractions.end())
164 ? extractions.find(index)->second
165 : nullptr;
166 inst = (insertions.find(index) != insertions.end())
167 ? insertions.find(index)->second
168 : nullptr;
169 if (ext != nullptr && inst != nullptr) {
170 // Filter out trivial aux, i.e., extract directly to an insert.
171 if (ext != inst) {
172 candidate_pairs.insert(
173 std::make_pair(index, std::make_pair(ext, inst)));
174 }
175 }
176 }
177 VLOG(2) << "total candidate pairs:" << candidate_pairs.size() << std::endl;
178
179 // Passed to ReachabilityMap to decide the type of produce-consumer edges
180 // along the reachability path.
181 const auto add_dependencies = [](const HloInstruction* hlo,
182 std::vector<HloInstruction*>* inputs) {
183 HloInstruction* non_const_operand = nullptr;
184 int num_non_constants = 0;
185 for (HloInstruction* operand : hlo->operands()) {
186 if (!operand->IsConstant()) {
187 num_non_constants++;
188 non_const_operand = operand;
189 }
190 }
191 if (num_non_constants == 1 &&
192 (hlo->opcode() == HloOpcode::kGetTupleElement ||
193 hlo->opcode() == HloOpcode::kAdd ||
194 hlo->opcode() == HloOpcode::kMultiply ||
195 hlo->opcode() == HloOpcode::kDivide ||
196 hlo->opcode() == HloOpcode::kSubtract)) {
197 inputs->push_back(non_const_operand);
198 }
199 };
200
201 std::unique_ptr<HloReachabilityMap> hrm =
202 HloReachabilityMap::BuildWithRestrictions(
203 while_body,
204 absl::FunctionRef<void(const HloInstruction* hlo,
205 std::vector<HloInstruction*>* inputs)>(
206 add_dependencies));
207
208 for (auto candidates : candidate_pairs) {
209 VLOG(2) << "are reachable?:" << (candidates.second.first)->ToString()
210 << "*************" << (candidates.second.second)->ToString()
211 << std::endl;
212 if (hrm->IsReachable(candidates.second.first, candidates.second.second)) {
213 aux_ind_gte.push_back(candidates.second.first);
214 VLOG(2) << "YES";
215 } else {
216 VLOG(2) << "NO";
217 }
218 }
219 VLOG(2) << "num auxiliary candidates :" << aux_ind_gte.size();
220 return aux_ind_gte;
221 }
222
223 // Tries to get the tuple index of the induction variable of a while loop.
224 //
225 // Checks that the loop condition and body both plumb the induction variable
226 // through the same tuple index, and that they both apply exactly one op to the
227 // induction variable before deciding whether to do another loop iteration (in
228 // the loop condition's case) or packing the induction variable into the result
229 // tuple (in the loop body's case).
230 //
231 // Specifically, checks that the loop condition has structure
232 //
233 // root = op(constants, get-tuple-elem(param0, N), constants)
234 //
235 // and the loop body has the structure
236 //
237 // inc = op(constants, get-tuple-elem(param0, N), constants)
238 // root = tuple(..., inc, ...) // inc is N'th operand of tuple().
239 //
240 // If so, returns N. Otherwise, returns nullopt.
GetLoopInductionVarTupleIdx(const HloInstruction * while_op)241 optional<int64> GetLoopInductionVarTupleIdx(const HloInstruction* while_op) {
242 CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
243 VLOG(2) << "Finding induction variable for loop "
244 << while_op->ToShortString();
245
246 // The while_cond computation should have the form
247 //
248 // while_cond_root =
249 // op(constants, get-tuple-elem(while_cond_param, N), constants).
250 //
251 // If it does, set indvar_tuple_idx to N.
252 auto* while_cond = while_op->while_condition();
253 auto* while_cond_root = while_cond->root_instruction();
254 auto* while_cond_param = while_cond->parameter_instruction(0);
255 optional<int64> indvar_tuple_idx =
256 GetGTEOperandIndex(while_cond_root, while_cond_param);
257 if (!indvar_tuple_idx) {
258 VLOG(2) << "Induction variable not found in loop condition: "
259 << while_cond->root_instruction()->ToString();
260 return nullopt;
261 }
262
263 // The while_body computation should have the form
264 //
265 // while_body_inc =
266 // op(constants, get-tuple-elem(while_body_param, N), constants)
267 // while_body_root = tuple(..., while_body_inc, ...)
268 //
269 // where while_body_inc is operand N of while_body_root.
270 auto* while_body = while_op->while_body();
271 auto* while_body_root = while_body->root_instruction();
272 if (while_body_root->opcode() != HloOpcode::kTuple) {
273 VLOG(2) << "While body's root is not a tuple instruction: "
274 << while_body_root->ToString();
275 return nullopt;
276 }
277
278 auto* while_body_inc = while_body_root->operand(*indvar_tuple_idx);
279 auto* while_body_param = while_body->parameter_instruction(0);
280 optional<int64> while_body_indvar_tuple_idx =
281 GetGTEOperandIndex(while_body_inc, while_body_param);
282 if (!while_body_indvar_tuple_idx) {
283 VLOG(2)
284 << "Induction variable not found in while body increment instruction: "
285 << while_body_inc->ToString();
286 return nullopt;
287 }
288 if (while_body_indvar_tuple_idx != indvar_tuple_idx) {
289 VLOG(2) << "Tuple index of induction variable does not match between loop "
290 "condition ("
291 << *indvar_tuple_idx << ") and while body ("
292 << *while_body_indvar_tuple_idx << ")";
293 return nullopt;
294 }
295
296 // Finally, check that the while loop's initial value is a tuple with enough
297 // elements.
298 auto* while_init = while_op->operand(0);
299 if (while_init->opcode() != HloOpcode::kTuple) {
300 VLOG(2) << "While init expected to be a tuple: " << while_init->ToString();
301 return nullopt;
302 }
303
304 VLOG(2) << "Induction variable's tuple index: " << *indvar_tuple_idx;
305 return indvar_tuple_idx;
306 }
307
308 // Converts the given literal to a scalar int64, if possible.
309 //
310 // Fails if the literal is not an integral type or if the value it contains
311 // cannot be represented in an int64.
LiteralAsScalarInt64(const Literal & l)312 static optional<int64> LiteralAsScalarInt64(const Literal& l) {
313 if (!ShapeUtil::IsEffectiveScalar(l.shape())) {
314 VLOG(2) << "literal is not an effective scalar: " << l.ToString();
315 return nullopt;
316 }
317 switch (l.shape().element_type()) {
318 case S8:
319 return l.GetFirstElement<int8>();
320 case S16:
321 return l.GetFirstElement<int16>();
322 case S32:
323 return l.GetFirstElement<int32>();
324 case S64:
325 return l.GetFirstElement<int64>();
326 case U8:
327 return l.GetFirstElement<uint8>();
328 case U16:
329 return l.GetFirstElement<uint16>();
330 case U32:
331 return l.GetFirstElement<uint32>();
332 case U64: {
333 uint64 v = l.GetFirstElement<uint64>();
334 if (v > static_cast<uint64>(std::numeric_limits<int64>::max())) {
335 VLOG(2) << "uint64 literal is out of range for int64: " << v;
336 return nullopt;
337 }
338 return v;
339 }
340 default:
341 VLOG(2) << "literal is of non-integral type " << l.shape().ToString();
342 return nullopt;
343 }
344 }
345
346 // Computes a + b, returning nullopt if it overflows.
CheckedAdd(int64 a,int64 b)347 optional<int64> CheckedAdd(int64 a, int64 b) {
348 // Overflow occurred iff `a` and `b` have the same sign and `a + b` has a
349 // different sign, see Hacker's Delignt 2nd Ed. pp 28.
350 uint64 aa = absl::bit_cast<uint64>(a);
351 uint64 bb = absl::bit_cast<uint64>(b);
352 int64 result = absl::bit_cast<int64>(aa + bb);
353 if (a >= 0 == b >= 0 && result >= 0 != a >= 0) {
354 return nullopt;
355 }
356 return result;
357 }
358
359 // Computes a - b, returning nullopt if it overflows.
CheckedSubtract(int64 a,int64 b)360 optional<int64> CheckedSubtract(int64 a, int64 b) {
361 uint64 aa = absl::bit_cast<uint64>(a);
362 uint64 bb = absl::bit_cast<uint64>(b);
363 int64 result = absl::bit_cast<int64>(aa - bb);
364 // Overflow occurred iff `a` and `b` have different signs and the sign of
365 // `a - b` is the same as that of `b`, see Hacker's Delight 2nd Ed. pp 29.
366 if (a >= 0 != b >= 0 && result >= 0 == b >= 0) {
367 return nullopt;
368 }
369 return result;
370 }
371
372 // Check if
373 // - `i` is initialized to a scalar constant K (namely, `indvar_init`),
374 // - the while condition does `i < N` or `i <= N`, and
375 // - the while body does `i++`.
376 // If so, it's trivial to compute the loop bound.
PatternMatchLoopTripCount(HloInstruction * while_op,int64 indvar_tuple_idx,const Literal & indvar_init)377 static optional<int64> PatternMatchLoopTripCount(HloInstruction* while_op,
378 int64 indvar_tuple_idx,
379 const Literal& indvar_init) {
380 // First, find the scalar constant K that `i` is initialized to.
381 optional<int64> indvar_init_val = LiteralAsScalarInt64(indvar_init);
382 if (!indvar_init_val) {
383 VLOG(2) << "Pattern-match failed: induction variable init is not a "
384 "constant scalar representable as an int64: "
385 << indvar_init.ToString();
386 return nullopt;
387 }
388
389 // Check that `i` goes as `i++` in the while body.
390 //
391 // TODO(jlebar): We could also handle i-- and other idioms.
392 auto* while_body = while_op->while_body();
393 auto* while_body_indvar_update =
394 while_body->root_instruction()->operand(indvar_tuple_idx);
395 auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
396 if (!Match(while_body_indvar_update,
397 m::AddAnyOrder(m::Op().Is(while_body_indvar),
398 m::ConstantEffectiveScalar(1)))) {
399 VLOG(2) << "Pattern-match failed: induction variable does not go as i++: "
400 << while_body_indvar_update->ToString();
401 return nullopt;
402 }
403
404 // Check that we do op(i, N) or op(N, i) as the while condition. Capture the
405 // value N.
406 auto* while_cond = while_op->while_condition();
407 auto* while_cond_root = while_cond->root_instruction();
408 auto* while_cond_indvar = NonConstantOperand(while_cond_root);
409 HloInstruction* while_cond_bound = nullptr;
410 if (!Match(while_cond_root,
411 m::Op().WithBinaryOperandsAnyOrder(
412 m::Op().Is(while_cond_indvar),
413 m::ConstantEffectiveScalar(&while_cond_bound)))) {
414 VLOG(2) << "Pattern-match failed: while condition is not of the form "
415 "op(i, N) or op(N, i).";
416 return nullopt;
417 }
418 // Note: If this succeeds, the constant `N` is representable as an int64 --
419 // that is, if it's an XLA U64, it fits within an int64.
420 optional<int64> while_cond_bound_val =
421 LiteralAsScalarInt64(while_cond_bound->literal());
422 if (!while_cond_bound_val) {
423 VLOG(2) << "Pattern-match failed: while condition induction variable is "
424 "not a constant scalar representable as an int64.";
425 return nullopt;
426 }
427
428 // Handle `i = K; i < N; ++i`.
429 if (Match(while_cond_root,
430 m::Op()
431 .WithComparisonDirection(ComparisonDirection::kLt)
432 .WithOperand(0, m::Op().Is(while_cond_indvar)))) {
433 VLOG(2) << "Pattern-match succeeded: loop condition is i < N: "
434 << while_cond_root->ToString();
435 optional<int64> trips =
436 CheckedSubtract(*while_cond_bound_val, *indvar_init_val);
437 if (trips) {
438 return std::max(int64{0}, *trips);
439 } else {
440 VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX.";
441 return nullopt;
442 }
443 }
444
445 // Handle `i = K; i <= N; ++i`.
446 if (Match(while_cond_root,
447 m::Op()
448 .WithComparisonDirection(ComparisonDirection::kLe)
449 .WithOperand(0, m::Op().Is(while_cond_indvar)))) {
450 VLOG(2) << "Pattern-match succeeded: loop condition is i <= N: "
451 << while_cond_root->ToString();
452 optional<int64> trips =
453 CheckedSubtract(*while_cond_bound_val, *indvar_init_val);
454 if (!trips) {
455 VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX";
456 return nullopt;
457 }
458 trips = CheckedAdd(*trips, 1);
459 if (!trips) {
460 VLOG(2) << "Pattern-match failed: Trip count exceeds INT64_MAX";
461 return nullopt;
462 }
463 return std::max<int64>(0, *trips);
464 }
465
466 VLOG(2) << "Pattern-match failed: while condition follows unknown pattern: "
467 << while_cond_root->ToString();
468 return nullopt;
469 }
470
ComputeWhileLoopTripCount(HloInstruction * while_op,int64 max_brute_force_iters)471 optional<int64> ComputeWhileLoopTripCount(HloInstruction* while_op,
472 int64 max_brute_force_iters) {
473 VLOG(2) << "Getting trip count for loop " << while_op->ToString();
474
475 // The loop's induction variable is found at
476 //
477 // get-tuple-elem(comp->parameter_instruction(0), *indvar_tuple_idx),
478 //
479 // where comp is while_op->while_body() or while_op->while_condition().
480 optional<int64> indvar_tuple_idx = GetLoopInductionVarTupleIdx(while_op);
481 if (!indvar_tuple_idx) {
482 return nullopt;
483 }
484
485 // Now that we know the index of the induction variable, we can we can try to
486 // compute how many times the loop executes. Start by computing the induction
487 // variable's initial value.
488 HloEvaluator evaluator(/*max_loop_iterations=*/0);
489 auto* while_init = while_op->mutable_operand(0);
490 auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx);
491 StatusOr<Literal> indvar_init_result = evaluator.Evaluate(indvar_init);
492 if (!indvar_init_result.ok()) {
493 VLOG(2) << "Couldn't evaluate induction variable init, "
494 << indvar_init_result.status() << ", " << indvar_init->ToString();
495 return nullopt;
496 }
497 Literal indvar_iter_val = std::move(indvar_init_result).ValueOrDie();
498
499 // First, try to pattern-match.
500 if (auto trip_count = PatternMatchLoopTripCount(while_op, *indvar_tuple_idx,
501 indvar_iter_val)) {
502 return trip_count;
503 }
504
505 // If our pattern-match failed, try brute-forcing the loop trip count.
506 auto* while_body = while_op->while_body();
507 auto* while_body_indvar_update =
508 while_body->root_instruction()->operand(*indvar_tuple_idx);
509 auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
510
511 auto* while_cond = while_op->while_condition();
512 auto* while_cond_root = while_cond->root_instruction();
513 auto* while_cond_indvar = NonConstantOperand(while_cond_root);
514
515 for (int64 trip_count = 0; trip_count != max_brute_force_iters + 1;
516 ++trip_count) {
517 StatusOr<Literal> result = evaluator.EvaluateWithSubstitutions(
518 while_cond_root, {{while_cond_indvar, &indvar_iter_val}});
519 if (!result.ok()) {
520 VLOG(2) << "Couldn't evaluate while cond: " << result.status();
521 return nullopt;
522 }
523 if (result.ValueOrDie().data<bool>() == absl::Span<const bool>{false}) {
524 VLOG(2) << "Loop has static trip count of " << trip_count;
525 return trip_count;
526 }
527
528 // Calculate the value of the induction variable after one iteration of the
529 // loop, and check whether the while condition is true with this new value.
530 StatusOr<Literal> indvar_next_result = evaluator.EvaluateWithSubstitutions(
531 while_body_indvar_update, {{while_body_indvar, &indvar_iter_val}});
532 if (!indvar_next_result.ok()) {
533 VLOG(2) << "Couldn't evaluate induction variable update: "
534 << indvar_next_result.status();
535 return nullopt;
536 }
537 indvar_iter_val = std::move(indvar_next_result).ValueOrDie();
538 }
539
540 VLOG(2) << "Loop has unknown trip count.";
541 return nullopt;
542 }
543
544 // If the only user of this instruction is a get-tuple-element, return that
545 // get-tuple-element, otherwise return null. If this runs before CSE/DCE, we may
546 // get a false negative if there are several copies of the same GTE, or there
547 // are unused GTEs, but we can live with this.
GetOnlyGTE(HloInstruction * inst)548 static HloInstruction* GetOnlyGTE(HloInstruction* inst) {
549 if (inst->user_count() != 1) {
550 return nullptr;
551 }
552
553 HloInstruction* user = inst->users().back();
554 if (user->opcode() != HloOpcode::kGetTupleElement) {
555 return nullptr;
556 }
557 return user;
558 }
559
ComputeWhileLoopTripCountUpperBound(HloInstruction * while_op)560 optional<int64> ComputeWhileLoopTripCountUpperBound(HloInstruction* while_op) {
561 // If we know the exact trip count, it's also the upper bound.
562 auto exact_trip_count = ComputeWhileLoopTripCount(while_op);
563 if (exact_trip_count) {
564 VLOG(2) << "Loop has exact trip count.";
565 return exact_trip_count;
566 }
567
568 // There is one more case we know how to handle. If the loop condition only
569 // looks at one element of the tuple, and the loop body sets this element to a
570 // constant, there are two options:
571 // 1) Evaluating the condition on this constant returns true. In this case,
572 // the loop either executes 0 times, or is an infinite loop, depending on the
573 // init value.
574 // 2) Evaluating the condition on this constant returns false. In this case,
575 // the loop executes 0 or 1 times, depending on the init value. This means
576 // that, regardless of the init value, the upper bound on the trip count is 1.
577
578 // Check whether the condition depends on a single parameter, and find out
579 // which.
580 auto* while_cond = while_op->while_condition();
581 auto* while_cond_param = while_cond->parameter_instruction(0);
582 auto* cond_gte = GetOnlyGTE(while_cond_param);
583 if (!cond_gte) {
584 VLOG(2) << "Induction variable not found in loop condition: "
585 << while_cond->root_instruction()->ToString();
586 return nullopt;
587 }
588
589 // Now check whether this gets set to a constant by the while body.
590 auto* while_body = while_op->while_body();
591 auto* while_body_root = while_body->root_instruction();
592 if (while_body_root->opcode() != HloOpcode::kTuple) {
593 VLOG(3) << "While body's root is not a tuple instruction: "
594 << while_body_root->ToString();
595 return nullopt;
596 }
597
598 int64 indvar_index = cond_gte->tuple_index();
599 auto* while_body_indvar = while_body_root->operand(indvar_index);
600 if (while_body_indvar->opcode() != HloOpcode::kConstant) {
601 VLOG(3) << "While body does not set the IV to a constant: "
602 << while_body_indvar->ToString();
603 return nullopt;
604 }
605
606 // We have a constant. Evaluate the condition on this constant.
607 HloEvaluator evaluator(/*max_loop_iterations=*/0);
608 Literal fake_input = Literal::CreateFromShape(while_cond_param->shape());
609 TF_CHECK_OK(fake_input.CopyFrom(while_body_indvar->literal(),
610 /*dest_shape_index=*/{indvar_index},
611 /*src_shape_index=*/{}));
612 StatusOr<Literal> eval_result =
613 evaluator.Evaluate(*while_cond, {std::move(fake_input)});
614
615 if (!eval_result.ok()) {
616 VLOG(2) << "Couldn't evaluate while loop condition.";
617 return nullopt;
618 }
619
620 Literal cond_result_pred = std::move(eval_result.ValueOrDie());
621 CHECK(Shape::Equal().IgnoreLayout()(cond_result_pred.shape(),
622 ShapeUtil::MakeShape(PRED, {})));
623
624 // Per the explanation above, if the evaluated condition returns false, the
625 // loop executes at most once.
626 bool cond_returns_true = cond_result_pred.GetFirstElement<bool>();
627 if (!cond_returns_true) {
628 VLOG(2) << "Upper bound on the trip count is 1";
629 return 1;
630 }
631
632 VLOG(2) << "Loop has no known upper bound on the trip count.";
633 return nullopt;
634 }
635
636 } // namespace xla
637