• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/loop_unswitch_pass.h"
16 
17 #include <functional>
18 #include <list>
19 #include <memory>
20 #include <unordered_map>
21 #include <unordered_set>
22 #include <utility>
23 #include <vector>
24 
25 #include "source/opt/basic_block.h"
26 #include "source/opt/dominator_tree.h"
27 #include "source/opt/fold.h"
28 #include "source/opt/function.h"
29 #include "source/opt/instruction.h"
30 #include "source/opt/ir_builder.h"
31 #include "source/opt/ir_context.h"
32 #include "source/opt/loop_descriptor.h"
33 #include "source/opt/loop_utils.h"
34 
35 namespace spvtools {
36 namespace opt {
37 namespace {
// In-operand index of the storage-class operand of an OpTypePointer
// instruction (read by IsDynamicallyUniform when inspecting OpLoad pointers).
constexpr uint32_t kTypePointerStorageClassInIdx{0};
39 
40 // This class handles the unswitch procedure for a given loop.
41 // The unswitch will not happen if:
42 //  - The loop has any instruction that will prevent it;
43 //  - The loop invariant condition is not uniform.
44 class LoopUnswitch {
45  public:
LoopUnswitch(IRContext * context,Function * function,Loop * loop,LoopDescriptor * loop_desc)46   LoopUnswitch(IRContext* context, Function* function, Loop* loop,
47                LoopDescriptor* loop_desc)
48       : function_(function),
49         loop_(loop),
50         loop_desc_(*loop_desc),
51         context_(context),
52         switch_block_(nullptr) {}
53 
54   // Returns true if the loop can be unswitched.
55   // Can be unswitch if:
56   //  - The loop has no instructions that prevents it (such as barrier);
57   //  - The loop has one conditional branch or switch that do not depends on the
58   //  loop;
59   //  - The loop invariant condition is uniform;
CanUnswitchLoop()60   bool CanUnswitchLoop() {
61     if (switch_block_) return true;
62     if (loop_->IsSafeToClone()) return false;
63 
64     CFG& cfg = *context_->cfg();
65 
66     for (uint32_t bb_id : loop_->GetBlocks()) {
67       BasicBlock* bb = cfg.block(bb_id);
68       if (loop_->GetLatchBlock() == bb) {
69         continue;
70       }
71 
72       if (bb->terminator()->IsBranch() &&
73           bb->terminator()->opcode() != spv::Op::OpBranch) {
74         if (IsConditionNonConstantLoopInvariant(bb->terminator())) {
75           switch_block_ = bb;
76           break;
77         }
78       }
79     }
80 
81     return switch_block_;
82   }
83 
84   // Return the iterator to the basic block |bb|.
FindBasicBlockPosition(BasicBlock * bb_to_find)85   Function::iterator FindBasicBlockPosition(BasicBlock* bb_to_find) {
86     Function::iterator it = function_->FindBlock(bb_to_find->id());
87     assert(it != function_->end() && "Basic Block not found");
88     return it;
89   }
90 
91   // Creates a new basic block and insert it into the function |fn| at the
92   // position |ip|. This function preserves the def/use and instr to block
93   // managers.
CreateBasicBlock(Function::iterator ip)94   BasicBlock* CreateBasicBlock(Function::iterator ip) {
95     analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
96 
97     // TODO(1841): Handle id overflow.
98     BasicBlock* bb = &*ip.InsertBefore(std::unique_ptr<BasicBlock>(
99         new BasicBlock(std::unique_ptr<Instruction>(new Instruction(
100             context_, spv::Op::OpLabel, 0, context_->TakeNextId(), {})))));
101     bb->SetParent(function_);
102     def_use_mgr->AnalyzeInstDef(bb->GetLabelInst());
103     context_->set_instr_block(bb->GetLabelInst(), bb);
104 
105     return bb;
106   }
107 
GetValueForDefaultPathForSwitch(Instruction * switch_inst)108   Instruction* GetValueForDefaultPathForSwitch(Instruction* switch_inst) {
109     assert(switch_inst->opcode() == spv::Op::OpSwitch &&
110            "The given instructoin must be an OpSwitch.");
111 
112     // Find a value that can be used to select the default path.
113     // If none are possible, then it will just use 0.  The value does not matter
114     // because this path will never be taken because the new switch outside of
115     // the loop cannot select this path either.
116     std::vector<uint32_t> existing_values;
117     for (uint32_t i = 2; i < switch_inst->NumInOperands(); i += 2) {
118       existing_values.push_back(switch_inst->GetSingleWordInOperand(i));
119     }
120     std::sort(existing_values.begin(), existing_values.end());
121     uint32_t value_for_default_path = 0;
122     if (existing_values.size() < std::numeric_limits<uint32_t>::max()) {
123       for (value_for_default_path = 0;
124            value_for_default_path < existing_values.size();
125            value_for_default_path++) {
126         if (existing_values[value_for_default_path] != value_for_default_path) {
127           break;
128         }
129       }
130     }
131     InstructionBuilder builder(
132         context_, static_cast<Instruction*>(nullptr),
133         IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
134     return builder.GetUintConstant(value_for_default_path);
135   }
136 
137   // Unswitches |loop_|.
  // Hoists the branch selected by CanUnswitchLoop() (the terminator of
  // |switch_block_|) out of the loop, in these steps:
  //  - |loop_| is cloned once per entry of |constant_branch|: the constant
  //    with literal word 0 (false for a boolean condition) for a conditional
  //    branch, or every case literal for a switch;
  //  - each copy is specialized for its value;
  //  - the original loop is specialized for the remaining value (literal
  //    word 1, or a value that reaches the switch's default target);
  //  - the old preheader is re-terminated with a branch or switch on the
  //    condition, which selects the matching loop.
  // Preconditions (asserted): CanUnswitchLoop() holds, |loop_| has a
  // preheader, and |loop_| is in LCSSA form.
  // NOTE(review): CanUnswitchLoop() is only evaluated inside assert(). In
  // NDEBUG builds, |switch_block_| must already be set by an earlier call
  // (ProcessFunction does this).
  // On exit, |switch_block_| and |ordered_loop_blocks_| are reset. All
  // analyses except loop analysis are invalidated.
PerformUnswitch()138   void PerformUnswitch() {
139     assert(CanUnswitchLoop() &&
140            "Cannot unswitch if there is not constant condition");
141     assert(loop_->GetPreHeaderBlock() && "This loop has no pre-header block");
142     assert(loop_->IsLCSSA() && "This loop is not in LCSSA form");
143 
144     CFG& cfg = *context_->cfg();
145     DominatorTree* dom_tree =
146         &context_->GetDominatorAnalysis(function_)->GetDomTree();
147     analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
148     LoopUtils loop_utils(context_, loop_);
149 
150     //////////////////////////////////////////////////////////////////////////////
151     // Step 1: Create the if merge block for structured modules.
152     //    To do so, the |loop_| merge block will become the if's one and we
153     //    create a merge for the loop. This will limit the amount of duplicated
154     //    code the structured control flow imposes.
155     //    For non structured program, the new loop will be connected to
156     //    the old loop's exit blocks.
157     //////////////////////////////////////////////////////////////////////////////
158 
159     // Get the merge block if it exists.
160     BasicBlock* if_merge_block = loop_->GetMergeBlock();
161     // The merge block is only created if the loop has a unique exit block. We
162     // have this guarantee for structured loops, for compute loop it will
163     // trivially help maintain both a structured-like form and LCSAA.
164     BasicBlock* loop_merge_block =
165         if_merge_block
166             ? CreateBasicBlock(FindBasicBlockPosition(if_merge_block))
167             : nullptr;
168     if (loop_merge_block) {
169       // Add the instruction and update managers.
      // |loop_merge_block| is placed just before |if_merge_block| and
      // branches to it.
170       InstructionBuilder builder(
171           context_, loop_merge_block,
172           IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
173       builder.AddBranch(if_merge_block->id());
174       builder.SetInsertPoint(&*loop_merge_block->begin());
175       cfg.RegisterBlock(loop_merge_block);
176       def_use_mgr->AnalyzeInstDef(loop_merge_block->GetLabelInst());
177       // Update CFG.
      // Each phi of |if_merge_block| is cloned (with a fresh id) into
      // |loop_merge_block|. The original phi is then reduced to the single
      // pair (clone, |loop_merge_block|).
178       if_merge_block->ForEachPhiInst(
179           [loop_merge_block, &builder, this](Instruction* phi) {
180             Instruction* cloned = phi->Clone(context_);
181             cloned->SetResultId(TakeNextId());
182             builder.AddInstruction(std::unique_ptr<Instruction>(cloned));
183             phi->SetInOperand(0, {cloned->result_id()});
184             phi->SetInOperand(1, {loop_merge_block->id()});
185             for (uint32_t j = phi->NumInOperands() - 1; j > 1; j--)
186               phi->RemoveInOperand(j);
187           });
188       // Copy the predecessor list (will get invalidated otherwise).
      // Every former predecessor of |if_merge_block| is redirected to
      // |loop_merge_block|.
189       std::vector<uint32_t> preds = cfg.preds(if_merge_block->id());
190       for (uint32_t pid : preds) {
191         if (pid == loop_merge_block->id()) continue;
192         BasicBlock* p_bb = cfg.block(pid);
193         p_bb->ForEachSuccessorLabel(
194             [if_merge_block, loop_merge_block](uint32_t* id) {
195               if (*id == if_merge_block->id()) *id = loop_merge_block->id();
196             });
197         cfg.AddEdge(pid, loop_merge_block->id());
198       }
199       cfg.RemoveNonExistingEdges(if_merge_block->id());
200       // Update loop descriptor.
201       if (Loop* ploop = loop_->GetParent()) {
202         ploop->AddBasicBlock(loop_merge_block);
203         loop_desc_.SetBasicBlockToLoop(loop_merge_block->id(), ploop);
204       }
205       // Update the dominator tree.
      // |loop_merge_block| takes |if_merge_block|'s place under its old
      // dominator, and |if_merge_block| becomes its child.
      // NOTE(review): if_merge_block_dtn->parent_ is not re-pointed at
      // |loop_merge_dtn| here. Only the children lists are updated before
      // ResetDFNumbering() below; confirm this is intended.
206       DominatorTreeNode* loop_merge_dtn =
207           dom_tree->GetOrInsertNode(loop_merge_block);
208       DominatorTreeNode* if_merge_block_dtn =
209           dom_tree->GetOrInsertNode(if_merge_block);
210       loop_merge_dtn->parent_ = if_merge_block_dtn->parent_;
211       loop_merge_dtn->children_.push_back(if_merge_block_dtn);
212       loop_merge_dtn->parent_->children_.push_back(loop_merge_dtn);
213       if_merge_block_dtn->parent_->children_.erase(std::find(
214           if_merge_block_dtn->parent_->children_.begin(),
215           if_merge_block_dtn->parent_->children_.end(), if_merge_block_dtn));
216 
217       loop_->SetMergeBlock(loop_merge_block);
218     }
219 
220     ////////////////////////////////////////////////////////////////////////////
221     // Step 2: Build a new preheader for |loop_|, use the old one
222     //         for the invariant branch.
223     ////////////////////////////////////////////////////////////////////////////
224 
225     BasicBlock* if_block = loop_->GetPreHeaderBlock();
226     // If this preheader is the parent loop header,
227     // we need to create a dedicated block for the if.
    // (A new preheader is created unconditionally. It is inserted right after
    // |if_block| and jumps straight to the loop header.)
228     BasicBlock* loop_pre_header =
229         CreateBasicBlock(++FindBasicBlockPosition(if_block));
230     InstructionBuilder(
231         context_, loop_pre_header,
232         IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping)
233         .AddBranch(loop_->GetHeaderBlock()->id());
234 
    // Retarget the old preheader's branch to the new preheader.
235     if_block->tail()->SetInOperand(0, {loop_pre_header->id()});
236 
237     // Update loop descriptor.
    // The new preheader joins whichever loop contains |if_block|, if any.
238     if (Loop* ploop = loop_desc_[if_block]) {
239       ploop->AddBasicBlock(loop_pre_header);
240       loop_desc_.SetBasicBlockToLoop(loop_pre_header->id(), ploop);
241     }
242 
243     // Update the CFG.
244     cfg.RegisterBlock(loop_pre_header);
245     def_use_mgr->AnalyzeInstDef(loop_pre_header->GetLabelInst());
246     cfg.AddEdge(if_block->id(), loop_pre_header->id());
247     cfg.RemoveNonExistingEdges(loop_->GetHeaderBlock()->id());
248 
    // Header phis that named |if_block| as a predecessor now name the new
    // preheader.
249     loop_->GetHeaderBlock()->ForEachPhiInst(
250         [loop_pre_header, if_block](Instruction* phi) {
251           phi->ForEachInId([loop_pre_header, if_block](uint32_t* id) {
252             if (*id == if_block->id()) {
253               *id = loop_pre_header->id();
254             }
255           });
256         });
257     loop_->SetPreHeaderBlock(loop_pre_header);
258 
259     // Update the dominator tree.
    // Insert the new preheader between |if_block| and its only dominator-tree
    // child.
260     DominatorTreeNode* loop_pre_header_dtn =
261         dom_tree->GetOrInsertNode(loop_pre_header);
262     DominatorTreeNode* if_block_dtn = dom_tree->GetTreeNode(if_block);
263     loop_pre_header_dtn->parent_ = if_block_dtn;
264     assert(
265         if_block_dtn->children_.size() == 1 &&
266         "A loop preheader should only have the header block as a child in the "
267         "dominator tree");
268     loop_pre_header_dtn->children_.push_back(if_block_dtn->children_[0]);
269     if_block_dtn->children_.clear();
270     if_block_dtn->children_.push_back(loop_pre_header_dtn);
271 
272     // Make domination queries valid.
273     dom_tree->ResetDFNumbering();
274 
275     // Compute an ordered list of basic block to clone: loop blocks + pre-header
276     // + merge block.
277     loop_->ComputeLoopStructuredOrder(&ordered_loop_blocks_, true, true);
278 
279     /////////////////////////////
280     // Do the actual unswitch: //
281     //   - Clone the loop      //
282     //   - Connect exits       //
283     //   - Specialize the loop //
284     /////////////////////////////
285 
    // |condition| is the value tested by the selected branch (in-operand 0 of
    // the OpBranchConditional / OpSwitch).
286     Instruction* iv_condition = &*switch_block_->tail();
287     spv::Op iv_opcode = iv_condition->opcode();
288     Instruction* condition =
289         def_use_mgr->GetDef(iv_condition->GetOperand(0).words[0]);
290 
291     analysis::ConstantManager* cst_mgr = context_->get_constant_mgr();
292     const analysis::Type* cond_type =
293         context_->get_type_mgr()->GetType(condition->type_id());
294 
295     // Build the list of value for which we need to clone and specialize the
296     // loop.
    // Each entry pairs a specialization constant with the preheader of the
    // clone created for it (filled in during Step 3).
297     std::vector<std::pair<Instruction*, BasicBlock*>> constant_branch;
298     // Special case for the original loop
    // (it keeps literal word 1 for a conditional branch, or takes the
    // switch's default path).
299     Instruction* original_loop_constant_value;
300     if (iv_opcode == spv::Op::OpBranchConditional) {
301       constant_branch.emplace_back(
302           cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant(cond_type, {0})),
303           nullptr);
304       original_loop_constant_value =
305           cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant(cond_type, {1}));
306     } else {
307       // We are looking to take the default branch, so we can't provide a
308       // specific value.
309       original_loop_constant_value =
310           GetValueForDefaultPathForSwitch(iv_condition);
311 
      // One clone per case literal (in-operands 2, 4, 6, ...).
312       for (uint32_t i = 2; i < iv_condition->NumInOperands(); i += 2) {
313         constant_branch.emplace_back(
314             cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant(
315                 cond_type, iv_condition->GetInOperand(i).words)),
316             nullptr);
317       }
318     }
319 
320     // Get the loop landing pads.
    // Structured loops (with an OpLoopMerge) land on |if_merge_block|, and
    // edges from the loop or its new merge block count as "original".
    // Otherwise every exit block is a landing pad.
    // NOTE(review): this assumes GetMergeBlock() was non-null whenever the
    // header has an OpLoopMerge; confirm.
321     std::unordered_set<uint32_t> if_merging_blocks;
322     std::function<bool(uint32_t)> is_from_original_loop;
323     if (loop_->GetHeaderBlock()->GetLoopMergeInst()) {
324       if_merging_blocks.insert(if_merge_block->id());
325       is_from_original_loop = [this](uint32_t id) {
326         return loop_->IsInsideLoop(id) || loop_->GetMergeBlock()->id() == id;
327       };
328     } else {
329       loop_->GetExitBlocks(&if_merging_blocks);
330       is_from_original_loop = [this](uint32_t id) {
331         return loop_->IsInsideLoop(id);
332       };
333     }
334 
335     for (auto& specialisation_pair : constant_branch) {
336       Instruction* specialisation_value = specialisation_pair.first;
337       //////////////////////////////////////////////////////////
338       // Step 3: Duplicate |loop_|.
339       //////////////////////////////////////////////////////////
340       LoopUtils::LoopCloningResult clone_result;
341 
342       Loop* cloned_loop =
343           loop_utils.CloneLoop(&clone_result, ordered_loop_blocks_);
344       specialisation_pair.second = cloned_loop->GetPreHeaderBlock();
345 
346       ////////////////////////////////////
347       // Step 4: Specialize the loop.   //
348       ////////////////////////////////////
349 
350       {
351         SpecializeLoop(cloned_loop, condition, specialisation_value);
352 
353         ///////////////////////////////////////////////////////////
354         // Step 5: Connect convergent edges to the landing pads. //
355         ///////////////////////////////////////////////////////////
356 
357         for (uint32_t merge_bb_id : if_merging_blocks) {
358           BasicBlock* merge = context_->cfg()->block(merge_bb_id);
359           // We are in LCSSA so we only care about phi instructions.
          // For each (value, pred) pair whose pred comes from the original
          // loop, append a copy whose pred (and value, when it was cloned)
          // is remapped to the clone. |num_in_operands| is captured up front
          // so the appended pairs are not revisited.
360           merge->ForEachPhiInst(
361               [is_from_original_loop, &clone_result](Instruction* phi) {
362                 uint32_t num_in_operands = phi->NumInOperands();
363                 for (uint32_t i = 0; i < num_in_operands; i += 2) {
364                   uint32_t pred = phi->GetSingleWordInOperand(i + 1);
365                   if (is_from_original_loop(pred)) {
366                     pred = clone_result.value_map_.at(pred);
367                     uint32_t incoming_value_id = phi->GetSingleWordInOperand(i);
368                     // Not all the incoming values are coming from the loop.
369                     ValueMapTy::iterator new_value =
370                         clone_result.value_map_.find(incoming_value_id);
371                     if (new_value != clone_result.value_map_.end()) {
372                       incoming_value_id = new_value->second;
373                     }
374                     phi->AddOperand({SPV_OPERAND_TYPE_ID, {incoming_value_id}});
375                     phi->AddOperand({SPV_OPERAND_TYPE_ID, {pred}});
376                   }
377                 }
378               });
379         }
380       }
      // The cloned blocks are inserted right after |if_block|.
381       function_->AddBasicBlocks(clone_result.cloned_bb_.begin(),
382                                 clone_result.cloned_bb_.end(),
383                                 ++FindBasicBlockPosition(if_block));
384     }
385 
386     // Specialize the existing loop.
387     SpecializeLoop(loop_, condition, original_loop_constant_value);
388     BasicBlock* original_loop_target = loop_->GetPreHeaderBlock();
389 
390     /////////////////////////////////////
391     // Finally: connect the new loops. //
392     /////////////////////////////////////
393 
394     // Delete the old jump
    // and re-terminate |if_block|:
    //  - conditional branch: the original loop is the true target and the
    //    single clone is the false target;
    //  - switch: the original loop is the default target and each clone is
    //    reached by its case literal.
    // The selection merges at |if_merge_block| when one exists.
395     context_->KillInst(&*if_block->tail());
396     InstructionBuilder builder(context_, if_block);
397     if (iv_opcode == spv::Op::OpBranchConditional) {
398       assert(constant_branch.size() == 1);
399       builder.AddConditionalBranch(
400           condition->result_id(), original_loop_target->id(),
401           constant_branch[0].second->id(),
402           if_merge_block ? if_merge_block->id() : kInvalidId);
403     } else {
404       std::vector<std::pair<Operand::OperandData, uint32_t>> targets;
405       for (auto& t : constant_branch) {
406         targets.emplace_back(t.first->GetInOperand(0).words, t.second->id());
407       }
408 
409       builder.AddSwitch(condition->result_id(), original_loop_target->id(),
410                         targets,
411                         if_merge_block ? if_merge_block->id() : kInvalidId);
412     }
413 
    // Reset per-unswitch state so CanUnswitchLoop() searches again.
414     switch_block_ = nullptr;
415     ordered_loop_blocks_.clear();
416 
417     context_->InvalidateAnalysesExceptFor(
418         IRContext::Analysis::kAnalysisLoopAnalysis);
419   }
420 
421  private:
  // Id-to-id map type; PerformUnswitch() uses its iterator to look up a
  // cloning result's value_map_.
422   using ValueMapTy = std::unordered_map<uint32_t, uint32_t>;
  // NOTE(review): BlockMapTy is not referenced by any code visible in this
  // file.
423   using BlockMapTy = std::unordered_map<uint32_t, BasicBlock*>;
424 
  // Function containing |loop_|.
425   Function* function_;
  // Loop being unswitched.
426   Loop* loop_;
  // Loop descriptor of |function_|; new blocks are registered in it.
427   LoopDescriptor& loop_desc_;
  // Context used for analyses, id allocation and IR edits.
428   IRContext* context_;
429 
  // Block whose terminator CanUnswitchLoop() selected for unswitching, or
  // nullptr if none is selected; cleared at the end of PerformUnswitch().
430   BasicBlock* switch_block_;
431   // Memoizes IsDynamicallyUniform(): result id -> whether it is uniform.
432   std::unordered_map<uint32_t, bool> dynamically_uniform_;
433   // The loop basic blocks in structured order.
  // Computed by PerformUnswitch() (including preheader and merge block),
  // used as the list of blocks to clone, and cleared when it finishes.
434   std::vector<BasicBlock*> ordered_loop_blocks_;
435 
436   // Returns the next usable id for the context.
TakeNextId()437   uint32_t TakeNextId() {
438     // TODO(1841): Handle id overflow.
439     return context_->TakeNextId();
440   }
441 
442   // Simplifies |loop| assuming the instruction |to_version_insn| takes the
443   // value |cst_value|. |block_range| is an iterator range returning the loop
444   // basic blocks in a structured order (dominator first).
445   // The function will ignore basic blocks returned by |block_range| if they
446   // does not belong to the loop.
447   // The set |dead_blocks| will contain all the dead basic blocks.
448   //
449   // Requirements:
450   //   - |loop| must be in the LCSSA form;
451   //   - |cst_value| must be constant.
SpecializeLoop(Loop * loop,Instruction * to_version_insn,Instruction * cst_value)452   void SpecializeLoop(Loop* loop, Instruction* to_version_insn,
453                       Instruction* cst_value) {
454     analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
455 
456     std::function<bool(uint32_t)> ignore_node;
457     ignore_node = [loop](uint32_t bb_id) { return !loop->IsInsideLoop(bb_id); };
458 
459     std::vector<std::pair<Instruction*, uint32_t>> use_list;
460     def_use_mgr->ForEachUse(to_version_insn,
461                             [&use_list, &ignore_node, this](
462                                 Instruction* inst, uint32_t operand_index) {
463                               BasicBlock* bb = context_->get_instr_block(inst);
464 
465                               if (!bb || ignore_node(bb->id())) {
466                                 // Out of the loop, the specialization does not
467                                 // apply any more.
468                                 return;
469                               }
470                               use_list.emplace_back(inst, operand_index);
471                             });
472 
473     // First pass: inject the specialized value into the loop (and only the
474     // loop).
475     for (auto use : use_list) {
476       Instruction* inst = use.first;
477       uint32_t operand_index = use.second;
478 
479       // To also handle switch, cst_value can be nullptr: this case
480       // means that we are looking to branch to the default target of
481       // the switch. We don't actually know its value so we don't touch
482       // it if it not a switch.
483       assert(cst_value && "We do not have a value to use.");
484       inst->SetOperand(operand_index, {cst_value->result_id()});
485       def_use_mgr->AnalyzeInstUse(inst);
486     }
487   }
488 
489   // Returns true if |var| is dynamically uniform.
490   // Note: this is currently approximated as uniform.
IsDynamicallyUniform(Instruction * var,const BasicBlock * entry,const DominatorTree & post_dom_tree)491   bool IsDynamicallyUniform(Instruction* var, const BasicBlock* entry,
492                             const DominatorTree& post_dom_tree) {
493     assert(post_dom_tree.IsPostDominator());
494     analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
495 
496     auto it = dynamically_uniform_.find(var->result_id());
497 
498     if (it != dynamically_uniform_.end()) return it->second;
499 
500     analysis::DecorationManager* dec_mgr = context_->get_decoration_mgr();
501 
502     bool& is_uniform = dynamically_uniform_[var->result_id()];
503     is_uniform = false;
504 
505     dec_mgr->WhileEachDecoration(var->result_id(),
506                                  uint32_t(spv::Decoration::Uniform),
507                                  [&is_uniform](const Instruction&) {
508                                    is_uniform = true;
509                                    return false;
510                                  });
511     if (is_uniform) {
512       return is_uniform;
513     }
514 
515     BasicBlock* parent = context_->get_instr_block(var);
516     if (!parent) {
517       return is_uniform = true;
518     }
519 
520     if (!post_dom_tree.Dominates(parent->id(), entry->id())) {
521       return is_uniform = false;
522     }
523     if (var->opcode() == spv::Op::OpLoad) {
524       const uint32_t PtrTypeId =
525           def_use_mgr->GetDef(var->GetSingleWordInOperand(0))->type_id();
526       const Instruction* PtrTypeInst = def_use_mgr->GetDef(PtrTypeId);
527       auto storage_class = spv::StorageClass(
528           PtrTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx));
529       if (storage_class != spv::StorageClass::Uniform &&
530           storage_class != spv::StorageClass::UniformConstant) {
531         return is_uniform = false;
532       }
533     } else {
534       if (!context_->IsCombinatorInstruction(var)) {
535         return is_uniform = false;
536       }
537     }
538 
539     return is_uniform = var->WhileEachInId([entry, &post_dom_tree,
540                                             this](const uint32_t* id) {
541       return IsDynamicallyUniform(context_->get_def_use_mgr()->GetDef(*id),
542                                   entry, post_dom_tree);
543     });
544   }
545 
546   // Returns true if |insn| is not a constant, but is loop invariant and
547   // dynamically uniform.
IsConditionNonConstantLoopInvariant(Instruction * insn)548   bool IsConditionNonConstantLoopInvariant(Instruction* insn) {
549     assert(insn->IsBranch());
550     assert(insn->opcode() != spv::Op::OpBranch);
551     analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
552 
553     Instruction* condition = def_use_mgr->GetDef(insn->GetOperand(0).words[0]);
554     if (condition->IsConstant()) {
555       return false;
556     }
557 
558     if (loop_->IsInsideLoop(condition)) {
559       return false;
560     }
561 
562     return IsDynamicallyUniform(
563         condition, function_->entry().get(),
564         context_->GetPostDominatorAnalysis(function_)->GetDomTree());
565   }
566 };
567 
568 }  // namespace
569 
Process()570 Pass::Status LoopUnswitchPass::Process() {
571   bool modified = false;
572   Module* module = context()->module();
573 
574   // Process each function in the module
575   for (Function& f : *module) {
576     modified |= ProcessFunction(&f);
577   }
578 
579   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
580 }
581 
ProcessFunction(Function * f)582 bool LoopUnswitchPass::ProcessFunction(Function* f) {
583   bool modified = false;
584   std::unordered_set<Loop*> processed_loop;
585 
586   LoopDescriptor& loop_descriptor = *context()->GetLoopDescriptor(f);
587 
588   bool loop_changed = true;
589   while (loop_changed) {
590     loop_changed = false;
591     for (Loop& loop : make_range(
592              ++TreeDFIterator<Loop>(loop_descriptor.GetPlaceholderRootLoop()),
593              TreeDFIterator<Loop>())) {
594       if (processed_loop.count(&loop)) continue;
595       processed_loop.insert(&loop);
596 
597       LoopUnswitch unswitcher(context(), f, &loop, &loop_descriptor);
598       while (unswitcher.CanUnswitchLoop()) {
599         if (!loop.IsLCSSA()) {
600           LoopUtils(context(), &loop).MakeLoopClosedSSA();
601         }
602         modified = true;
603         loop_changed = true;
604         unswitcher.PerformUnswitch();
605       }
606       if (loop_changed) break;
607     }
608   }
609 
610   return modified;
611 }
612 
613 }  // namespace opt
614 }  // namespace spvtools
615