• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2019 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/reduce/remove_selection_reduction_opportunity_finder.h"
16 
17 #include "source/reduce/remove_selection_reduction_opportunity.h"
18 
19 namespace spvtools {
20 namespace reduce {
21 
namespace {
// Operand positions passed to GetSingleWordOperand() on merge instructions in
// this file: index 0 is read as the merge block id (for both OpLoopMerge and
// OpSelectionMerge), index 1 as the continue block id of an OpLoopMerge.
constexpr uint32_t kMergeNodeIndex = 0;
constexpr uint32_t kContinueNodeIndex = 1;
}  // namespace
26 
GetName() const27 std::string RemoveSelectionReductionOpportunityFinder::GetName() const {
28   return "RemoveSelectionReductionOpportunityFinder";
29 }
30 
31 std::vector<std::unique_ptr<ReductionOpportunity>>
GetAvailableOpportunities(opt::IRContext * context,uint32_t target_function) const32 RemoveSelectionReductionOpportunityFinder::GetAvailableOpportunities(
33     opt::IRContext* context, uint32_t target_function) const {
34   // Get all loop merge and continue blocks so we can check for these later.
35   std::unordered_set<uint32_t> merge_and_continue_blocks_from_loops;
36   for (auto* function : GetTargetFunctions(context, target_function)) {
37     for (auto& block : *function) {
38       if (auto merge_instruction = block.GetMergeInst()) {
39         if (merge_instruction->opcode() == SpvOpLoopMerge) {
40           uint32_t merge_block_id =
41               merge_instruction->GetSingleWordOperand(kMergeNodeIndex);
42           uint32_t continue_block_id =
43               merge_instruction->GetSingleWordOperand(kContinueNodeIndex);
44           merge_and_continue_blocks_from_loops.insert(merge_block_id);
45           merge_and_continue_blocks_from_loops.insert(continue_block_id);
46         }
47       }
48     }
49   }
50 
51   // Return all selection headers where the OpSelectionMergeInstruction can be
52   // removed.
53   std::vector<std::unique_ptr<ReductionOpportunity>> result;
54   for (auto& function : *context->module()) {
55     for (auto& block : function) {
56       if (auto merge_instruction = block.GetMergeInst()) {
57         if (merge_instruction->opcode() == SpvOpSelectionMerge) {
58           if (CanOpSelectionMergeBeRemoved(
59                   context, block, merge_instruction,
60                   merge_and_continue_blocks_from_loops)) {
61             result.push_back(
62                 MakeUnique<RemoveSelectionReductionOpportunity>(&block));
63           }
64         }
65       }
66     }
67   }
68   return result;
69 }
70 
CanOpSelectionMergeBeRemoved(opt::IRContext * context,const opt::BasicBlock & header_block,opt::Instruction * merge_instruction,std::unordered_set<uint32_t> merge_and_continue_blocks_from_loops)71 bool RemoveSelectionReductionOpportunityFinder::CanOpSelectionMergeBeRemoved(
72     opt::IRContext* context, const opt::BasicBlock& header_block,
73     opt::Instruction* merge_instruction,
74     std::unordered_set<uint32_t> merge_and_continue_blocks_from_loops) {
75   assert(header_block.GetMergeInst() == merge_instruction &&
76          "CanOpSelectionMergeBeRemoved(...): header block and merge "
77          "instruction mismatch");
78 
79   // The OpSelectionMerge instruction is needed if either of the following are
80   // true.
81   //
82   // 1. The header block has at least two (unique) successors that are not
83   // merge or continue blocks of a loop.
84   //
85   // 2. The predecessors of the merge block are "using" the merge block to avoid
86   // divergence. In other words, there exists a predecessor of the merge block
87   // that has a successor that is not the merge block of this construct and not
88   // a merge or continue block of a loop.
89 
90   // 1.
91   {
92     uint32_t divergent_successor_count = 0;
93 
94     std::unordered_set<uint32_t> seen_successors;
95 
96     header_block.ForEachSuccessorLabel(
97         [&seen_successors, &merge_and_continue_blocks_from_loops,
98          &divergent_successor_count](uint32_t successor) {
99           // Not already seen.
100           if (seen_successors.find(successor) == seen_successors.end()) {
101             seen_successors.insert(successor);
102             // Not a loop continue or merge.
103             if (merge_and_continue_blocks_from_loops.find(successor) ==
104                 merge_and_continue_blocks_from_loops.end()) {
105               ++divergent_successor_count;
106             }
107           }
108         });
109 
110     if (divergent_successor_count > 1) {
111       return false;
112     }
113   }
114 
115   // 2.
116   {
117     uint32_t merge_block_id =
118         merge_instruction->GetSingleWordOperand(kMergeNodeIndex);
119     for (uint32_t predecessor_block_id :
120          context->cfg()->preds(merge_block_id)) {
121       const opt::BasicBlock* predecessor_block =
122           context->cfg()->block(predecessor_block_id);
123       assert(predecessor_block);
124       bool found_divergent_successor = false;
125       predecessor_block->ForEachSuccessorLabel(
126           [&found_divergent_successor, merge_block_id,
127            &merge_and_continue_blocks_from_loops](uint32_t successor_id) {
128             // The successor is not the merge block, nor a loop merge or
129             // continue.
130             if (successor_id != merge_block_id &&
131                 merge_and_continue_blocks_from_loops.find(successor_id) ==
132                     merge_and_continue_blocks_from_loops.end()) {
133               found_divergent_successor = true;
134             }
135           });
136       if (found_divergent_successor) {
137         return false;
138       }
139     }
140   }
141 
142   return true;
143 }
144 
145 }  // namespace reduce
146 }  // namespace spvtools
147