• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/loop_descriptor.h"
16 
17 #include <algorithm>
18 #include <iostream>
19 #include <limits>
20 #include <stack>
21 #include <type_traits>
22 #include <utility>
23 #include <vector>
24 
25 #include "source/opt/cfg.h"
26 #include "source/opt/constants.h"
27 #include "source/opt/dominator_tree.h"
28 #include "source/opt/ir_builder.h"
29 #include "source/opt/ir_context.h"
30 #include "source/opt/iterator.h"
31 #include "source/opt/tree_iterator.h"
32 #include "source/util/make_unique.h"
33 
34 namespace spvtools {
35 namespace opt {
36 
37 // Takes in a phi instruction |induction| and the loop |header| and returns the
38 // step operation of the loop.
GetInductionStepOperation(const Instruction * induction) const39 Instruction* Loop::GetInductionStepOperation(
40     const Instruction* induction) const {
41   // Induction must be a phi instruction.
42   assert(induction->opcode() == SpvOpPhi);
43 
44   Instruction* step = nullptr;
45 
46   analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
47 
48   // Traverse the incoming operands of the phi instruction.
49   for (uint32_t operand_id = 1; operand_id < induction->NumInOperands();
50        operand_id += 2) {
51     // Incoming edge.
52     BasicBlock* incoming_block =
53         context_->cfg()->block(induction->GetSingleWordInOperand(operand_id));
54 
55     // Check if the block is dominated by header, and thus coming from within
56     // the loop.
57     if (IsInsideLoop(incoming_block)) {
58       step = def_use_manager->GetDef(
59           induction->GetSingleWordInOperand(operand_id - 1));
60       break;
61     }
62   }
63 
64   if (!step || !IsSupportedStepOp(step->opcode())) {
65     return nullptr;
66   }
67 
68   // The induction variable which binds the loop must only be modified once.
69   uint32_t lhs = step->GetSingleWordInOperand(0);
70   uint32_t rhs = step->GetSingleWordInOperand(1);
71 
72   // One of the left hand side or right hand side of the step instruction must
73   // be the induction phi and the other must be an OpConstant.
74   if (lhs != induction->result_id() && rhs != induction->result_id()) {
75     return nullptr;
76   }
77 
78   if (def_use_manager->GetDef(lhs)->opcode() != SpvOp::SpvOpConstant &&
79       def_use_manager->GetDef(rhs)->opcode() != SpvOp::SpvOpConstant) {
80     return nullptr;
81   }
82 
83   return step;
84 }
85 
86 // Returns true if the |step| operation is an induction variable step operation
87 // which is currently handled.
IsSupportedStepOp(SpvOp step) const88 bool Loop::IsSupportedStepOp(SpvOp step) const {
89   switch (step) {
90     case SpvOp::SpvOpISub:
91     case SpvOp::SpvOpIAdd:
92       return true;
93     default:
94       return false;
95   }
96 }
97 
IsSupportedCondition(SpvOp condition) const98 bool Loop::IsSupportedCondition(SpvOp condition) const {
99   switch (condition) {
100     // <
101     case SpvOp::SpvOpULessThan:
102     case SpvOp::SpvOpSLessThan:
103     // >
104     case SpvOp::SpvOpUGreaterThan:
105     case SpvOp::SpvOpSGreaterThan:
106 
107     // >=
108     case SpvOp::SpvOpSGreaterThanEqual:
109     case SpvOp::SpvOpUGreaterThanEqual:
110     // <=
111     case SpvOp::SpvOpSLessThanEqual:
112     case SpvOp::SpvOpULessThanEqual:
113 
114       return true;
115     default:
116       return false;
117   }
118 }
119 
GetResidualConditionValue(SpvOp condition,int64_t initial_value,int64_t step_value,size_t number_of_iterations,size_t factor)120 int64_t Loop::GetResidualConditionValue(SpvOp condition, int64_t initial_value,
121                                         int64_t step_value,
122                                         size_t number_of_iterations,
123                                         size_t factor) {
124   int64_t remainder =
125       initial_value + (number_of_iterations % factor) * step_value;
126 
127   // We subtract or add one as the above formula calculates the remainder if the
128   // loop where just less than or greater than. Adding or subtracting one should
129   // give a functionally equivalent value.
130   switch (condition) {
131     case SpvOp::SpvOpSGreaterThanEqual:
132     case SpvOp::SpvOpUGreaterThanEqual: {
133       remainder -= 1;
134       break;
135     }
136     case SpvOp::SpvOpSLessThanEqual:
137     case SpvOp::SpvOpULessThanEqual: {
138       remainder += 1;
139       break;
140     }
141 
142     default:
143       break;
144   }
145   return remainder;
146 }
147 
GetConditionInst() const148 Instruction* Loop::GetConditionInst() const {
149   BasicBlock* condition_block = FindConditionBlock();
150   if (!condition_block) {
151     return nullptr;
152   }
153   Instruction* branch_conditional = &*condition_block->tail();
154   if (!branch_conditional ||
155       branch_conditional->opcode() != SpvOpBranchConditional) {
156     return nullptr;
157   }
158   Instruction* condition_inst = context_->get_def_use_mgr()->GetDef(
159       branch_conditional->GetSingleWordInOperand(0));
160   if (IsSupportedCondition(condition_inst->opcode())) {
161     return condition_inst;
162   }
163 
164   return nullptr;
165 }
166 
167 // Extract the initial value from the |induction| OpPhi instruction and store it
168 // in |value|. If the function couldn't find the initial value of |induction|
169 // return false.
GetInductionInitValue(const Instruction * induction,int64_t * value) const170 bool Loop::GetInductionInitValue(const Instruction* induction,
171                                  int64_t* value) const {
172   Instruction* constant_instruction = nullptr;
173   analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
174 
175   for (uint32_t operand_id = 0; operand_id < induction->NumInOperands();
176        operand_id += 2) {
177     BasicBlock* bb = context_->cfg()->block(
178         induction->GetSingleWordInOperand(operand_id + 1));
179 
180     if (!IsInsideLoop(bb)) {
181       constant_instruction = def_use_manager->GetDef(
182           induction->GetSingleWordInOperand(operand_id));
183     }
184   }
185 
186   if (!constant_instruction) return false;
187 
188   const analysis::Constant* constant =
189       context_->get_constant_mgr()->FindDeclaredConstant(
190           constant_instruction->result_id());
191   if (!constant) return false;
192 
193   if (value) {
194     const analysis::Integer* type =
195         constant->AsIntConstant()->type()->AsInteger();
196 
197     if (type->IsSigned()) {
198       *value = constant->AsIntConstant()->GetS32BitValue();
199     } else {
200       *value = constant->AsIntConstant()->GetU32BitValue();
201     }
202   }
203 
204   return true;
205 }
206 
Loop(IRContext * context,DominatorAnalysis * dom_analysis,BasicBlock * header,BasicBlock * continue_target,BasicBlock * merge_target)207 Loop::Loop(IRContext* context, DominatorAnalysis* dom_analysis,
208            BasicBlock* header, BasicBlock* continue_target,
209            BasicBlock* merge_target)
210     : context_(context),
211       loop_header_(header),
212       loop_continue_(continue_target),
213       loop_merge_(merge_target),
214       loop_preheader_(nullptr),
215       parent_(nullptr),
216       loop_is_marked_for_removal_(false) {
217   assert(context);
218   assert(dom_analysis);
219   loop_preheader_ = FindLoopPreheader(dom_analysis);
220   loop_latch_ = FindLatchBlock();
221 }
222 
FindLoopPreheader(DominatorAnalysis * dom_analysis)223 BasicBlock* Loop::FindLoopPreheader(DominatorAnalysis* dom_analysis) {
224   CFG* cfg = context_->cfg();
225   DominatorTree& dom_tree = dom_analysis->GetDomTree();
226   DominatorTreeNode* header_node = dom_tree.GetTreeNode(loop_header_);
227 
228   // The loop predecessor.
229   BasicBlock* loop_pred = nullptr;
230 
231   auto header_pred = cfg->preds(loop_header_->id());
232   for (uint32_t p_id : header_pred) {
233     DominatorTreeNode* node = dom_tree.GetTreeNode(p_id);
234     if (node && !dom_tree.Dominates(header_node, node)) {
235       // The predecessor is not part of the loop, so potential loop preheader.
236       if (loop_pred && node->bb_ != loop_pred) {
237         // If we saw 2 distinct predecessors that are outside the loop, we don't
238         // have a loop preheader.
239         return nullptr;
240       }
241       loop_pred = node->bb_;
242     }
243   }
244   // Safe guard against invalid code, SPIR-V spec forbids loop with the entry
245   // node as header.
246   assert(loop_pred && "The header node is the entry block ?");
247 
248   // So we have a unique basic block that can enter this loop.
249   // If this loop is the unique successor of this block, then it is a loop
250   // preheader.
251   bool is_preheader = true;
252   uint32_t loop_header_id = loop_header_->id();
253   const auto* const_loop_pred = loop_pred;
254   const_loop_pred->ForEachSuccessorLabel(
255       [&is_preheader, loop_header_id](const uint32_t id) {
256         if (id != loop_header_id) is_preheader = false;
257       });
258   if (is_preheader) return loop_pred;
259   return nullptr;
260 }
261 
IsInsideLoop(Instruction * inst) const262 bool Loop::IsInsideLoop(Instruction* inst) const {
263   const BasicBlock* parent_block = context_->get_instr_block(inst);
264   if (!parent_block) return false;
265   return IsInsideLoop(parent_block);
266 }
267 
IsBasicBlockInLoopSlow(const BasicBlock * bb)268 bool Loop::IsBasicBlockInLoopSlow(const BasicBlock* bb) {
269   assert(bb->GetParent() && "The basic block does not belong to a function");
270   DominatorAnalysis* dom_analysis =
271       context_->GetDominatorAnalysis(bb->GetParent());
272   if (dom_analysis->IsReachable(bb) &&
273       !dom_analysis->Dominates(GetHeaderBlock(), bb))
274     return false;
275 
276   return true;
277 }
278 
GetOrCreatePreHeaderBlock()279 BasicBlock* Loop::GetOrCreatePreHeaderBlock() {
280   if (loop_preheader_) return loop_preheader_;
281 
282   CFG* cfg = context_->cfg();
283   loop_header_ = cfg->SplitLoopHeader(loop_header_);
284   return loop_preheader_;
285 }
286 
SetContinueBlock(BasicBlock * continue_block)287 void Loop::SetContinueBlock(BasicBlock* continue_block) {
288   assert(IsInsideLoop(continue_block));
289   loop_continue_ = continue_block;
290 }
291 
SetLatchBlock(BasicBlock * latch)292 void Loop::SetLatchBlock(BasicBlock* latch) {
293 #ifndef NDEBUG
294   assert(latch->GetParent() && "The basic block does not belong to a function");
295 
296   const auto* const_latch = latch;
297   const_latch->ForEachSuccessorLabel([this](uint32_t id) {
298     assert((!IsInsideLoop(id) || id == GetHeaderBlock()->id()) &&
299            "A predecessor of the continue block does not belong to the loop");
300   });
301 #endif  // NDEBUG
302   assert(IsInsideLoop(latch) && "The continue block is not in the loop");
303 
304   SetLatchBlockImpl(latch);
305 }
306 
SetMergeBlock(BasicBlock * merge)307 void Loop::SetMergeBlock(BasicBlock* merge) {
308 #ifndef NDEBUG
309   assert(merge->GetParent() && "The basic block does not belong to a function");
310 #endif  // NDEBUG
311   assert(!IsInsideLoop(merge) && "The merge block is in the loop");
312 
313   SetMergeBlockImpl(merge);
314   if (GetHeaderBlock()->GetLoopMergeInst()) {
315     UpdateLoopMergeInst();
316   }
317 }
318 
SetPreHeaderBlock(BasicBlock * preheader)319 void Loop::SetPreHeaderBlock(BasicBlock* preheader) {
320   if (preheader) {
321     assert(!IsInsideLoop(preheader) && "The preheader block is in the loop");
322     assert(preheader->tail()->opcode() == SpvOpBranch &&
323            "The preheader block does not unconditionally branch to the header "
324            "block");
325     assert(preheader->tail()->GetSingleWordOperand(0) ==
326                GetHeaderBlock()->id() &&
327            "The preheader block does not unconditionally branch to the header "
328            "block");
329   }
330   loop_preheader_ = preheader;
331 }
332 
FindLatchBlock()333 BasicBlock* Loop::FindLatchBlock() {
334   CFG* cfg = context_->cfg();
335 
336   DominatorAnalysis* dominator_analysis =
337       context_->GetDominatorAnalysis(loop_header_->GetParent());
338 
339   // Look at the predecessors of the loop header to find a predecessor block
340   // which is dominated by the loop continue target. There should only be one
341   // block which meets this criteria and this is the latch block, as per the
342   // SPIR-V spec.
343   for (uint32_t block_id : cfg->preds(loop_header_->id())) {
344     if (dominator_analysis->Dominates(loop_continue_->id(), block_id)) {
345       return cfg->block(block_id);
346     }
347   }
348 
349   assert(
350       false &&
351       "Every loop should have a latch block dominated by the continue target");
352   return nullptr;
353 }
354 
GetExitBlocks(std::unordered_set<uint32_t> * exit_blocks) const355 void Loop::GetExitBlocks(std::unordered_set<uint32_t>* exit_blocks) const {
356   CFG* cfg = context_->cfg();
357   exit_blocks->clear();
358 
359   for (uint32_t bb_id : GetBlocks()) {
360     const BasicBlock* bb = cfg->block(bb_id);
361     bb->ForEachSuccessorLabel([exit_blocks, this](uint32_t succ) {
362       if (!IsInsideLoop(succ)) {
363         exit_blocks->insert(succ);
364       }
365     });
366   }
367 }
368 
GetMergingBlocks(std::unordered_set<uint32_t> * merging_blocks) const369 void Loop::GetMergingBlocks(
370     std::unordered_set<uint32_t>* merging_blocks) const {
371   assert(GetMergeBlock() && "This loop is not structured");
372   CFG* cfg = context_->cfg();
373   merging_blocks->clear();
374 
375   std::stack<const BasicBlock*> to_visit;
376   to_visit.push(GetMergeBlock());
377   while (!to_visit.empty()) {
378     const BasicBlock* bb = to_visit.top();
379     to_visit.pop();
380     merging_blocks->insert(bb->id());
381     for (uint32_t pred_id : cfg->preds(bb->id())) {
382       if (!IsInsideLoop(pred_id) && !merging_blocks->count(pred_id)) {
383         to_visit.push(cfg->block(pred_id));
384       }
385     }
386   }
387 }
388 
389 namespace {
390 
IsBasicBlockSafeToClone(IRContext * context,BasicBlock * bb)391 static inline bool IsBasicBlockSafeToClone(IRContext* context, BasicBlock* bb) {
392   for (Instruction& inst : *bb) {
393     if (!inst.IsBranch() && !context->IsCombinatorInstruction(&inst))
394       return false;
395   }
396 
397   return true;
398 }
399 
400 }  // namespace
401 
IsSafeToClone() const402 bool Loop::IsSafeToClone() const {
403   CFG& cfg = *context_->cfg();
404 
405   for (uint32_t bb_id : GetBlocks()) {
406     BasicBlock* bb = cfg.block(bb_id);
407     assert(bb);
408     if (!IsBasicBlockSafeToClone(context_, bb)) return false;
409   }
410 
411   // Look at the merge construct.
412   if (GetHeaderBlock()->GetLoopMergeInst()) {
413     std::unordered_set<uint32_t> blocks;
414     GetMergingBlocks(&blocks);
415     blocks.erase(GetMergeBlock()->id());
416     for (uint32_t bb_id : blocks) {
417       BasicBlock* bb = cfg.block(bb_id);
418       assert(bb);
419       if (!IsBasicBlockSafeToClone(context_, bb)) return false;
420     }
421   }
422 
423   return true;
424 }
425 
IsLCSSA() const426 bool Loop::IsLCSSA() const {
427   CFG* cfg = context_->cfg();
428   analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
429 
430   std::unordered_set<uint32_t> exit_blocks;
431   GetExitBlocks(&exit_blocks);
432 
433   // Declare ir_context so we can capture context_ in the below lambda
434   IRContext* ir_context = context_;
435 
436   for (uint32_t bb_id : GetBlocks()) {
437     for (Instruction& insn : *cfg->block(bb_id)) {
438       // All uses must be either:
439       //  - In the loop;
440       //  - In an exit block and in a phi instruction.
441       if (!def_use_mgr->WhileEachUser(
442               &insn,
443               [&exit_blocks, ir_context, this](Instruction* use) -> bool {
444                 BasicBlock* parent = ir_context->get_instr_block(use);
445                 assert(parent && "Invalid analysis");
446                 if (IsInsideLoop(parent)) return true;
447                 if (use->opcode() != SpvOpPhi) return false;
448                 return exit_blocks.count(parent->id());
449               }))
450         return false;
451     }
452   }
453   return true;
454 }
455 
ShouldHoistInstruction(IRContext * context,Instruction * inst)456 bool Loop::ShouldHoistInstruction(IRContext* context, Instruction* inst) {
457   return AreAllOperandsOutsideLoop(context, inst) &&
458          inst->IsOpcodeCodeMotionSafe();
459 }
460 
AreAllOperandsOutsideLoop(IRContext * context,Instruction * inst)461 bool Loop::AreAllOperandsOutsideLoop(IRContext* context, Instruction* inst) {
462   analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr();
463   bool all_outside_loop = true;
464 
465   const std::function<void(uint32_t*)> operand_outside_loop =
466       [this, &def_use_mgr, &all_outside_loop](uint32_t* id) {
467         if (this->IsInsideLoop(def_use_mgr->GetDef(*id))) {
468           all_outside_loop = false;
469           return;
470         }
471       };
472 
473   inst->ForEachInId(operand_outside_loop);
474   return all_outside_loop;
475 }
476 
ComputeLoopStructuredOrder(std::vector<BasicBlock * > * ordered_loop_blocks,bool include_pre_header,bool include_merge) const477 void Loop::ComputeLoopStructuredOrder(
478     std::vector<BasicBlock*>* ordered_loop_blocks, bool include_pre_header,
479     bool include_merge) const {
480   CFG& cfg = *context_->cfg();
481 
482   // Reserve the memory: all blocks in the loop + extra if needed.
483   ordered_loop_blocks->reserve(GetBlocks().size() + include_pre_header +
484                                include_merge);
485 
486   if (include_pre_header && GetPreHeaderBlock())
487     ordered_loop_blocks->push_back(loop_preheader_);
488   cfg.ForEachBlockInReversePostOrder(
489       loop_header_, [ordered_loop_blocks, this](BasicBlock* bb) {
490         if (IsInsideLoop(bb)) ordered_loop_blocks->push_back(bb);
491       });
492   if (include_merge && GetMergeBlock())
493     ordered_loop_blocks->push_back(loop_merge_);
494 }
495 
LoopDescriptor(IRContext * context,const Function * f)496 LoopDescriptor::LoopDescriptor(IRContext* context, const Function* f)
497     : loops_(), dummy_top_loop_(nullptr) {
498   PopulateList(context, f);
499 }
500 
~LoopDescriptor()501 LoopDescriptor::~LoopDescriptor() { ClearLoops(); }
502 
PopulateList(IRContext * context,const Function * f)503 void LoopDescriptor::PopulateList(IRContext* context, const Function* f) {
504   DominatorAnalysis* dom_analysis = context->GetDominatorAnalysis(f);
505 
506   ClearLoops();
507 
508   // Post-order traversal of the dominator tree to find all the OpLoopMerge
509   // instructions.
510   DominatorTree& dom_tree = dom_analysis->GetDomTree();
511   for (DominatorTreeNode& node :
512        make_range(dom_tree.post_begin(), dom_tree.post_end())) {
513     Instruction* merge_inst = node.bb_->GetLoopMergeInst();
514     if (merge_inst) {
515       bool all_backedge_unreachable = true;
516       for (uint32_t pid : context->cfg()->preds(node.bb_->id())) {
517         if (dom_analysis->IsReachable(pid) &&
518             dom_analysis->Dominates(node.bb_->id(), pid)) {
519           all_backedge_unreachable = false;
520           break;
521         }
522       }
523       if (all_backedge_unreachable)
524         continue;  // ignore this one, we actually never branch back.
525 
526       // The id of the merge basic block of this loop.
527       uint32_t merge_bb_id = merge_inst->GetSingleWordOperand(0);
528 
529       // The id of the continue basic block of this loop.
530       uint32_t continue_bb_id = merge_inst->GetSingleWordOperand(1);
531 
532       // The merge target of this loop.
533       BasicBlock* merge_bb = context->cfg()->block(merge_bb_id);
534 
535       // The continue target of this loop.
536       BasicBlock* continue_bb = context->cfg()->block(continue_bb_id);
537 
538       // The basic block containing the merge instruction.
539       BasicBlock* header_bb = context->get_instr_block(merge_inst);
540 
541       // Add the loop to the list of all the loops in the function.
542       Loop* current_loop =
543           new Loop(context, dom_analysis, header_bb, continue_bb, merge_bb);
544       loops_.push_back(current_loop);
545 
546       // We have a bottom-up construction, so if this loop has nested-loops,
547       // they are by construction at the tail of the loop list.
548       for (auto itr = loops_.rbegin() + 1; itr != loops_.rend(); ++itr) {
549         Loop* previous_loop = *itr;
550 
551         // If the loop already has a parent, then it has been processed.
552         if (previous_loop->HasParent()) continue;
553 
554         // If the current loop does not dominates the previous loop then it is
555         // not nested loop.
556         if (!dom_analysis->Dominates(header_bb,
557                                      previous_loop->GetHeaderBlock()))
558           continue;
559         // If the current loop merge dominates the previous loop then it is
560         // not nested loop.
561         if (dom_analysis->Dominates(merge_bb, previous_loop->GetHeaderBlock()))
562           continue;
563 
564         current_loop->AddNestedLoop(previous_loop);
565       }
566       DominatorTreeNode* dom_merge_node = dom_tree.GetTreeNode(merge_bb);
567       for (DominatorTreeNode& loop_node :
568            make_range(node.df_begin(), node.df_end())) {
569         // Check if we are in the loop.
570         if (dom_tree.Dominates(dom_merge_node, &loop_node)) continue;
571         current_loop->AddBasicBlock(loop_node.bb_);
572         basic_block_to_loop_.insert(
573             std::make_pair(loop_node.bb_->id(), current_loop));
574       }
575     }
576   }
577   for (Loop* loop : loops_) {
578     if (!loop->HasParent()) dummy_top_loop_.nested_loops_.push_back(loop);
579   }
580 }
581 
GetLoopsInBinaryLayoutOrder()582 std::vector<Loop*> LoopDescriptor::GetLoopsInBinaryLayoutOrder() {
583   std::vector<uint32_t> ids{};
584 
585   for (size_t i = 0; i < NumLoops(); ++i) {
586     ids.push_back(GetLoopByIndex(i).GetHeaderBlock()->id());
587   }
588 
589   std::vector<Loop*> loops{};
590   if (!ids.empty()) {
591     auto function = GetLoopByIndex(0).GetHeaderBlock()->GetParent();
592     for (const auto& block : *function) {
593       auto block_id = block.id();
594 
595       auto element = std::find(std::begin(ids), std::end(ids), block_id);
596       if (element != std::end(ids)) {
597         loops.push_back(&GetLoopByIndex(element - std::begin(ids)));
598       }
599     }
600   }
601 
602   return loops;
603 }
604 
FindConditionBlock() const605 BasicBlock* Loop::FindConditionBlock() const {
606   if (!loop_merge_) {
607     return nullptr;
608   }
609   BasicBlock* condition_block = nullptr;
610 
611   uint32_t in_loop_pred = 0;
612   for (uint32_t p : context_->cfg()->preds(loop_merge_->id())) {
613     if (IsInsideLoop(p)) {
614       if (in_loop_pred) {
615         // 2 in-loop predecessors.
616         return nullptr;
617       }
618       in_loop_pred = p;
619     }
620   }
621   if (!in_loop_pred) {
622     // Merge block is unreachable.
623     return nullptr;
624   }
625 
626   BasicBlock* bb = context_->cfg()->block(in_loop_pred);
627 
628   if (!bb) return nullptr;
629 
630   const Instruction& branch = *bb->ctail();
631 
632   // Make sure the branch is a conditional branch.
633   if (branch.opcode() != SpvOpBranchConditional) return nullptr;
634 
635   // Make sure one of the two possible branches is to the merge block.
636   if (branch.GetSingleWordInOperand(1) == loop_merge_->id() ||
637       branch.GetSingleWordInOperand(2) == loop_merge_->id()) {
638     condition_block = bb;
639   }
640 
641   return condition_block;
642 }
643 
FindNumberOfIterations(const Instruction * induction,const Instruction * branch_inst,size_t * iterations_out,int64_t * step_value_out,int64_t * init_value_out) const644 bool Loop::FindNumberOfIterations(const Instruction* induction,
645                                   const Instruction* branch_inst,
646                                   size_t* iterations_out,
647                                   int64_t* step_value_out,
648                                   int64_t* init_value_out) const {
649   // From the branch instruction find the branch condition.
650   analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
651 
652   // Condition instruction from the OpConditionalBranch.
653   Instruction* condition =
654       def_use_manager->GetDef(branch_inst->GetSingleWordOperand(0));
655 
656   assert(IsSupportedCondition(condition->opcode()));
657 
658   // Get the constant manager from the ir context.
659   analysis::ConstantManager* const_manager = context_->get_constant_mgr();
660 
661   // Find the constant value used by the condition variable. Exit out if it
662   // isn't a constant int.
663   const analysis::Constant* upper_bound =
664       const_manager->FindDeclaredConstant(condition->GetSingleWordOperand(3));
665   if (!upper_bound) return false;
666 
667   // Must be integer because of the opcode on the condition.
668   int64_t condition_value = 0;
669 
670   const analysis::Integer* type =
671       upper_bound->AsIntConstant()->type()->AsInteger();
672 
673   if (type->width() > 32) {
674     return false;
675   }
676 
677   if (type->IsSigned()) {
678     condition_value = upper_bound->AsIntConstant()->GetS32BitValue();
679   } else {
680     condition_value = upper_bound->AsIntConstant()->GetU32BitValue();
681   }
682 
683   // Find the instruction which is stepping through the loop.
684   Instruction* step_inst = GetInductionStepOperation(induction);
685   if (!step_inst) return false;
686 
687   // Find the constant value used by the condition variable.
688   const analysis::Constant* step_constant =
689       const_manager->FindDeclaredConstant(step_inst->GetSingleWordOperand(3));
690   if (!step_constant) return false;
691 
692   // Must be integer because of the opcode on the condition.
693   int64_t step_value = 0;
694 
695   const analysis::Integer* step_type =
696       step_constant->AsIntConstant()->type()->AsInteger();
697 
698   if (step_type->IsSigned()) {
699     step_value = step_constant->AsIntConstant()->GetS32BitValue();
700   } else {
701     step_value = step_constant->AsIntConstant()->GetU32BitValue();
702   }
703 
704   // If this is a subtraction step we should negate the step value.
705   if (step_inst->opcode() == SpvOp::SpvOpISub) {
706     step_value = -step_value;
707   }
708 
709   // Find the inital value of the loop and make sure it is a constant integer.
710   int64_t init_value = 0;
711   if (!GetInductionInitValue(induction, &init_value)) return false;
712 
713   // If iterations is non null then store the value in that.
714   int64_t num_itrs = GetIterations(condition->opcode(), condition_value,
715                                    init_value, step_value);
716 
717   // If the loop body will not be reached return false.
718   if (num_itrs <= 0) {
719     return false;
720   }
721 
722   if (iterations_out) {
723     assert(static_cast<size_t>(num_itrs) <= std::numeric_limits<size_t>::max());
724     *iterations_out = static_cast<size_t>(num_itrs);
725   }
726 
727   if (step_value_out) {
728     *step_value_out = step_value;
729   }
730 
731   if (init_value_out) {
732     *init_value_out = init_value;
733   }
734 
735   return true;
736 }
737 
738 // We retrieve the number of iterations using the following formula, diff /
739 // |step_value| where diff is calculated differently according to the
740 // |condition| and uses the |condition_value| and |init_value|. If diff /
741 // |step_value| is NOT cleanly divisable then we add one to the sum.
GetIterations(SpvOp condition,int64_t condition_value,int64_t init_value,int64_t step_value) const742 int64_t Loop::GetIterations(SpvOp condition, int64_t condition_value,
743                             int64_t init_value, int64_t step_value) const {
744   int64_t diff = 0;
745 
746   switch (condition) {
747     case SpvOp::SpvOpSLessThan:
748     case SpvOp::SpvOpULessThan: {
749       // If the condition is not met to begin with the loop will never iterate.
750       if (!(init_value < condition_value)) return 0;
751 
752       diff = condition_value - init_value;
753 
754       // If the operation is a less then operation then the diff and step must
755       // have the same sign otherwise the induction will never cross the
756       // condition (either never true or always true).
757       if ((diff < 0 && step_value > 0) || (diff > 0 && step_value < 0)) {
758         return 0;
759       }
760 
761       break;
762     }
763     case SpvOp::SpvOpSGreaterThan:
764     case SpvOp::SpvOpUGreaterThan: {
765       // If the condition is not met to begin with the loop will never iterate.
766       if (!(init_value > condition_value)) return 0;
767 
768       diff = init_value - condition_value;
769 
770       // If the operation is a greater than operation then the diff and step
771       // must have opposite signs. Otherwise the condition will always be true
772       // or will never be true.
773       if ((diff < 0 && step_value < 0) || (diff > 0 && step_value > 0)) {
774         return 0;
775       }
776 
777       break;
778     }
779 
780     case SpvOp::SpvOpSGreaterThanEqual:
781     case SpvOp::SpvOpUGreaterThanEqual: {
782       // If the condition is not met to begin with the loop will never iterate.
783       if (!(init_value >= condition_value)) return 0;
784 
785       // We subract one to make it the same as SpvOpGreaterThan as it is
786       // functionally equivalent.
787       diff = init_value - (condition_value - 1);
788 
789       // If the operation is a greater than operation then the diff and step
790       // must have opposite signs. Otherwise the condition will always be true
791       // or will never be true.
792       if ((diff > 0 && step_value > 0) || (diff < 0 && step_value < 0)) {
793         return 0;
794       }
795 
796       break;
797     }
798 
799     case SpvOp::SpvOpSLessThanEqual:
800     case SpvOp::SpvOpULessThanEqual: {
801       // If the condition is not met to begin with the loop will never iterate.
802       if (!(init_value <= condition_value)) return 0;
803 
804       // We add one to make it the same as SpvOpLessThan as it is functionally
805       // equivalent.
806       diff = (condition_value + 1) - init_value;
807 
808       // If the operation is a less than operation then the diff and step must
809       // have the same sign otherwise the induction will never cross the
810       // condition (either never true or always true).
811       if ((diff < 0 && step_value > 0) || (diff > 0 && step_value < 0)) {
812         return 0;
813       }
814 
815       break;
816     }
817 
818     default:
819       assert(false &&
820              "Could not retrieve number of iterations from the loop condition. "
821              "Condition is not supported.");
822   }
823 
824   // Take the abs of - step values.
825   step_value = llabs(step_value);
826   diff = llabs(diff);
827   int64_t result = diff / step_value;
828 
829   if (diff % step_value != 0) {
830     result += 1;
831   }
832   return result;
833 }
834 
835 // Returns the list of induction variables within the loop.
GetInductionVariables(std::vector<Instruction * > & induction_variables) const836 void Loop::GetInductionVariables(
837     std::vector<Instruction*>& induction_variables) const {
838   for (Instruction& inst : *loop_header_) {
839     if (inst.opcode() == SpvOp::SpvOpPhi) {
840       induction_variables.push_back(&inst);
841     }
842   }
843 }
844 
FindConditionVariable(const BasicBlock * condition_block) const845 Instruction* Loop::FindConditionVariable(
846     const BasicBlock* condition_block) const {
847   // Find the branch instruction.
848   const Instruction& branch_inst = *condition_block->ctail();
849 
850   Instruction* induction = nullptr;
851   // Verify that the branch instruction is a conditional branch.
852   if (branch_inst.opcode() == SpvOp::SpvOpBranchConditional) {
853     // From the branch instruction find the branch condition.
854     analysis::DefUseManager* def_use_manager = context_->get_def_use_mgr();
855 
856     // Find the instruction representing the condition used in the conditional
857     // branch.
858     Instruction* condition =
859         def_use_manager->GetDef(branch_inst.GetSingleWordOperand(0));
860 
861     // Ensure that the condition is a less than operation.
862     if (condition && IsSupportedCondition(condition->opcode())) {
863       // The left hand side operand of the operation.
864       Instruction* variable_inst =
865           def_use_manager->GetDef(condition->GetSingleWordOperand(2));
866 
867       // Make sure the variable instruction used is a phi.
868       if (!variable_inst || variable_inst->opcode() != SpvOpPhi) return nullptr;
869 
870       // Make sure the phi instruction only has two incoming blocks. Each
871       // incoming block will be represented by two in operands in the phi
872       // instruction, the value and the block which that value came from. We
873       // assume the cannocalised phi will have two incoming values, one from the
874       // preheader and one from the continue block.
875       size_t max_supported_operands = 4;
876       if (variable_inst->NumInOperands() == max_supported_operands) {
877         // The operand index of the first incoming block label.
878         uint32_t operand_label_1 = 1;
879 
880         // The operand index of the second incoming block label.
881         uint32_t operand_label_2 = 3;
882 
883         // Make sure one of them is the preheader.
884         if (!IsInsideLoop(
885                 variable_inst->GetSingleWordInOperand(operand_label_1)) &&
886             !IsInsideLoop(
887                 variable_inst->GetSingleWordInOperand(operand_label_2))) {
888           return nullptr;
889         }
890 
891         // And make sure that the other is the latch block.
892         if (variable_inst->GetSingleWordInOperand(operand_label_1) !=
893                 loop_latch_->id() &&
894             variable_inst->GetSingleWordInOperand(operand_label_2) !=
895                 loop_latch_->id()) {
896           return nullptr;
897         }
898       } else {
899         return nullptr;
900       }
901 
902       if (!FindNumberOfIterations(variable_inst, &branch_inst, nullptr))
903         return nullptr;
904       induction = variable_inst;
905     }
906   }
907 
908   return induction;
909 }
910 
CreatePreHeaderBlocksIfMissing()911 bool LoopDescriptor::CreatePreHeaderBlocksIfMissing() {
912   auto modified = false;
913 
914   for (auto& loop : *this) {
915     if (!loop.GetPreHeaderBlock()) {
916       modified = true;
917       // TODO(1841): Handle failure to create pre-header.
918       loop.GetOrCreatePreHeaderBlock();
919     }
920   }
921 
922   return modified;
923 }
924 
925 // Add and remove loops which have been marked for addition and removal to
926 // maintain the state of the loop descriptor class.
PostModificationCleanup()927 void LoopDescriptor::PostModificationCleanup() {
928   LoopContainerType loops_to_remove_;
929   for (Loop* loop : loops_) {
930     if (loop->IsMarkedForRemoval()) {
931       loops_to_remove_.push_back(loop);
932       if (loop->HasParent()) {
933         loop->GetParent()->RemoveChildLoop(loop);
934       }
935     }
936   }
937 
938   for (Loop* loop : loops_to_remove_) {
939     loops_.erase(std::find(loops_.begin(), loops_.end(), loop));
940     delete loop;
941   }
942 
943   for (auto& pair : loops_to_add_) {
944     Loop* parent = pair.first;
945     std::unique_ptr<Loop> loop = std::move(pair.second);
946 
947     if (parent) {
948       loop->SetParent(nullptr);
949       parent->AddNestedLoop(loop.get());
950 
951       for (uint32_t block_id : loop->GetBlocks()) {
952         parent->AddBasicBlock(block_id);
953       }
954     }
955 
956     loops_.emplace_back(loop.release());
957   }
958 
959   loops_to_add_.clear();
960 }
961 
ClearLoops()962 void LoopDescriptor::ClearLoops() {
963   for (Loop* loop : loops_) {
964     delete loop;
965   }
966   loops_.clear();
967 }
968 
969 // Adds a new loop nest to the descriptor set.
AddLoopNest(std::unique_ptr<Loop> new_loop)970 Loop* LoopDescriptor::AddLoopNest(std::unique_ptr<Loop> new_loop) {
971   Loop* loop = new_loop.release();
972   if (!loop->HasParent()) dummy_top_loop_.nested_loops_.push_back(loop);
973   // Iterate from inner to outer most loop, adding basic block to loop mapping
974   // as we go.
975   for (Loop& current_loop :
976        make_range(iterator::begin(loop), iterator::end(nullptr))) {
977     loops_.push_back(&current_loop);
978     for (uint32_t bb_id : current_loop.GetBlocks())
979       basic_block_to_loop_.insert(std::make_pair(bb_id, &current_loop));
980   }
981 
982   return loop;
983 }
984 
RemoveLoop(Loop * loop)985 void LoopDescriptor::RemoveLoop(Loop* loop) {
986   Loop* parent = loop->GetParent() ? loop->GetParent() : &dummy_top_loop_;
987   parent->nested_loops_.erase(std::find(parent->nested_loops_.begin(),
988                                         parent->nested_loops_.end(), loop));
989   std::for_each(
990       loop->nested_loops_.begin(), loop->nested_loops_.end(),
991       [loop](Loop* sub_loop) { sub_loop->SetParent(loop->GetParent()); });
992   parent->nested_loops_.insert(parent->nested_loops_.end(),
993                                loop->nested_loops_.begin(),
994                                loop->nested_loops_.end());
995   for (uint32_t bb_id : loop->GetBlocks()) {
996     Loop* l = FindLoopForBasicBlock(bb_id);
997     if (l == loop) {
998       SetBasicBlockToLoop(bb_id, l->GetParent());
999     } else {
1000       ForgetBasicBlock(bb_id);
1001     }
1002   }
1003 
1004   LoopContainerType::iterator it =
1005       std::find(loops_.begin(), loops_.end(), loop);
1006   assert(it != loops_.end());
1007   delete loop;
1008   loops_.erase(it);
1009 }
1010 
1011 }  // namespace opt
1012 }  // namespace spvtools
1013