1 // Copyright (c) 2019 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/reduce/remove_selection_reduction_opportunity_finder.h"
16
17 #include "source/reduce/remove_selection_reduction_opportunity.h"
18
19 namespace spvtools {
20 namespace reduce {
21
namespace {
// Operand index of the <Merge Block> id; used below for both OpLoopMerge and
// OpSelectionMerge instructions.
constexpr uint32_t kMergeNodeIndex = 0;
// Operand index of the <Continue Target> id of an OpLoopMerge instruction.
constexpr uint32_t kContinueNodeIndex = 1;
}  // namespace
26
GetName() const27 std::string RemoveSelectionReductionOpportunityFinder::GetName() const {
28 return "RemoveSelectionReductionOpportunityFinder";
29 }
30
31 std::vector<std::unique_ptr<ReductionOpportunity>>
GetAvailableOpportunities(opt::IRContext * context,uint32_t target_function) const32 RemoveSelectionReductionOpportunityFinder::GetAvailableOpportunities(
33 opt::IRContext* context, uint32_t target_function) const {
34 // Get all loop merge and continue blocks so we can check for these later.
35 std::unordered_set<uint32_t> merge_and_continue_blocks_from_loops;
36 for (auto* function : GetTargetFunctions(context, target_function)) {
37 for (auto& block : *function) {
38 if (auto merge_instruction = block.GetMergeInst()) {
39 if (merge_instruction->opcode() == SpvOpLoopMerge) {
40 uint32_t merge_block_id =
41 merge_instruction->GetSingleWordOperand(kMergeNodeIndex);
42 uint32_t continue_block_id =
43 merge_instruction->GetSingleWordOperand(kContinueNodeIndex);
44 merge_and_continue_blocks_from_loops.insert(merge_block_id);
45 merge_and_continue_blocks_from_loops.insert(continue_block_id);
46 }
47 }
48 }
49 }
50
51 // Return all selection headers where the OpSelectionMergeInstruction can be
52 // removed.
53 std::vector<std::unique_ptr<ReductionOpportunity>> result;
54 for (auto& function : *context->module()) {
55 for (auto& block : function) {
56 if (auto merge_instruction = block.GetMergeInst()) {
57 if (merge_instruction->opcode() == SpvOpSelectionMerge) {
58 if (CanOpSelectionMergeBeRemoved(
59 context, block, merge_instruction,
60 merge_and_continue_blocks_from_loops)) {
61 result.push_back(
62 MakeUnique<RemoveSelectionReductionOpportunity>(&block));
63 }
64 }
65 }
66 }
67 }
68 return result;
69 }
70
CanOpSelectionMergeBeRemoved(opt::IRContext * context,const opt::BasicBlock & header_block,opt::Instruction * merge_instruction,std::unordered_set<uint32_t> merge_and_continue_blocks_from_loops)71 bool RemoveSelectionReductionOpportunityFinder::CanOpSelectionMergeBeRemoved(
72 opt::IRContext* context, const opt::BasicBlock& header_block,
73 opt::Instruction* merge_instruction,
74 std::unordered_set<uint32_t> merge_and_continue_blocks_from_loops) {
75 assert(header_block.GetMergeInst() == merge_instruction &&
76 "CanOpSelectionMergeBeRemoved(...): header block and merge "
77 "instruction mismatch");
78
79 // The OpSelectionMerge instruction is needed if either of the following are
80 // true.
81 //
82 // 1. The header block has at least two (unique) successors that are not
83 // merge or continue blocks of a loop.
84 //
85 // 2. The predecessors of the merge block are "using" the merge block to avoid
86 // divergence. In other words, there exists a predecessor of the merge block
87 // that has a successor that is not the merge block of this construct and not
88 // a merge or continue block of a loop.
89
90 // 1.
91 {
92 uint32_t divergent_successor_count = 0;
93
94 std::unordered_set<uint32_t> seen_successors;
95
96 header_block.ForEachSuccessorLabel(
97 [&seen_successors, &merge_and_continue_blocks_from_loops,
98 &divergent_successor_count](uint32_t successor) {
99 // Not already seen.
100 if (seen_successors.find(successor) == seen_successors.end()) {
101 seen_successors.insert(successor);
102 // Not a loop continue or merge.
103 if (merge_and_continue_blocks_from_loops.find(successor) ==
104 merge_and_continue_blocks_from_loops.end()) {
105 ++divergent_successor_count;
106 }
107 }
108 });
109
110 if (divergent_successor_count > 1) {
111 return false;
112 }
113 }
114
115 // 2.
116 {
117 uint32_t merge_block_id =
118 merge_instruction->GetSingleWordOperand(kMergeNodeIndex);
119 for (uint32_t predecessor_block_id :
120 context->cfg()->preds(merge_block_id)) {
121 const opt::BasicBlock* predecessor_block =
122 context->cfg()->block(predecessor_block_id);
123 assert(predecessor_block);
124 bool found_divergent_successor = false;
125 predecessor_block->ForEachSuccessorLabel(
126 [&found_divergent_successor, merge_block_id,
127 &merge_and_continue_blocks_from_loops](uint32_t successor_id) {
128 // The successor is not the merge block, nor a loop merge or
129 // continue.
130 if (successor_id != merge_block_id &&
131 merge_and_continue_blocks_from_loops.find(successor_id) ==
132 merge_and_continue_blocks_from_loops.end()) {
133 found_divergent_successor = true;
134 }
135 });
136 if (found_divergent_successor) {
137 return false;
138 }
139 }
140 }
141
142 return true;
143 }
144
145 } // namespace reduce
146 } // namespace spvtools
147