// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "code_sink.h"

#include <unordered_set>
#include <vector>

#include "source/opt/instruction.h"
#include "source/opt/ir_context.h"
#include "source/util/bit_vector.h"

namespace spvtools {
namespace opt {

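// Attempts to sink instructions in every block of every function in the
// module, visiting the blocks of each function in post order. Reports whether
// anything was moved.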
Pass::Status CodeSinkingPass::Process() {
  bool modified = false;
  for (Function& function : *get_module()) {
    cfg()->ForEachBlockInPostOrder(function.entry().get(),
                                   [&modified, this](BasicBlock* bb) {
                                     if (SinkInstructionsInBB(bb)) {
                                       modified = true;
                                     }
                                   });
  }
  return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}

bool CodeSinkingPass::SinkInstructionsInBB(BasicBlock* bb) {
  bool modified = false;
  for (auto inst = bb->rbegin(); inst != bb->rend(); ++inst) {
    if (SinkInstruction(&*inst)) {
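      // Moving |inst| to another block removes it from this block's
      // instruction list and invalidates the iterator, so restart the scan
      // from the end of the block.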
      inst = bb->rbegin();
      modified = true;
    }
  }
  return modified;
}

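// Tries to sink |inst| into a later block that still dominates all of its
// uses. Only OpAccessChain instructions and OpLoad instructions that do not
// reference mutable memory are considered. Returns true if |inst| was moved.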
bool CodeSinkingPass::SinkInstruction(Instruction* inst) {
  if (inst->opcode() != spv::Op::OpLoad &&
      inst->opcode() != spv::Op::OpAccessChain) {
    return false;
  }

  if (ReferencesMutableMemory(inst)) {
    return false;
  }

  if (BasicBlock* target_bb = FindNewBasicBlockFor(inst)) {
    Instruction* pos = &*target_bb->begin();
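    // OpPhi instructions must appear before all other instructions in a
    // block, so insert |inst| after any leading phis in |target_bb|.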
    while (pos->opcode() == spv::Op::OpPhi) {
      pos = pos->NextNode();
    }

    inst->InsertBefore(pos);
    context()->set_instr_block(inst, target_bb);
    return true;
  }
  return false;
}

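// Looks for a better block to hold |inst|. Starting from the block that
// currently contains |inst|, the search walks down the CFG, following
// unconditional branches into single-predecessor successors and stepping
// through selection constructs, as long as the candidate block still
// dominates every use of |inst| and will not execute more often than the
// original block. Returns nullptr if no better block is found.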
BasicBlock* CodeSinkingPass::FindNewBasicBlockFor(Instruction* inst) {
  assert(inst->result_id() != 0 && "Instruction should have a result.");
  BasicBlock* original_bb = context()->get_instr_block(inst);
  BasicBlock* bb = original_bb;

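  // Collect the ids of all blocks that use |inst|. A use in an OpPhi is
  // attributed to the corresponding predecessor block, because that is where
  // the value must be available.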
  std::unordered_set<uint32_t> bbs_with_uses;
  get_def_use_mgr()->ForEachUse(
      inst, [&bbs_with_uses, this](Instruction* use, uint32_t idx) {
        if (use->opcode() != spv::Op::OpPhi) {
          BasicBlock* use_bb = context()->get_instr_block(use);
          if (use_bb) {
            bbs_with_uses.insert(use_bb->id());
          }
        } else {
          bbs_with_uses.insert(use->GetSingleWordOperand(idx + 1));
        }
      });

  while (true) {
    // If |inst| is used in |bb|, then |inst| cannot be moved any further.
    if (bbs_with_uses.count(bb->id())) {
      break;
    }

    // If |bb| has one successor (succ_bb), and |bb| is the only predecessor
    // of succ_bb, then |inst| can be moved to succ_bb. If succ_bb has more
    // than one predecessor, then moving |inst| into succ_bb could cause it to
    // be executed more often, so the search has to stop.
    if (bb->terminator()->opcode() == spv::Op::OpBranch) {
      uint32_t succ_bb_id = bb->terminator()->GetSingleWordInOperand(0);
      if (cfg()->preds(succ_bb_id).size() == 1) {
        bb = context()->get_instr_block(succ_bb_id);
        continue;
      } else {
        break;
      }
    }

    // The remaining checks need to know the merge node. If there is no merge
    // instruction, or it is an OpLoopMerge, then this is a break or continue.
    // We could figure it out, but it is not worth doing now.
    Instruction* merge_inst = bb->GetMergeInst();
    if (merge_inst == nullptr ||
        merge_inst->opcode() != spv::Op::OpSelectionMerge) {
      break;
    }

    // Check all of the successors of |bb| to see which lead to a use of
    // |inst| before reaching the merge node.
    bool used_in_multiple_blocks = false;
    uint32_t bb_used_in = 0;
    bb->ForEachSuccessorLabel([this, bb, &bb_used_in, &used_in_multiple_blocks,
                               &bbs_with_uses](uint32_t* succ_bb_id) {
      if (IntersectsPath(*succ_bb_id, bb->MergeBlockIdIfAny(), bbs_with_uses)) {
        if (bb_used_in == 0) {
          bb_used_in = *succ_bb_id;
        } else {
          used_in_multiple_blocks = true;
        }
      }
    });

    // If |inst| is used along more than one of the paths from |bb| to the
    // merge block, then we have to leave |inst| in |bb| because none of the
    // successors dominates all uses of |inst|.
    if (used_in_multiple_blocks) {
      break;
    }

    if (bb_used_in == 0) {
      // If |inst| is not used before reaching the merge node, then we can move
      // |inst| to the merge node.
      bb = context()->get_instr_block(bb->MergeBlockIdIfAny());
    } else {
      // If the only successor that leads to a use of |inst| has more than 1
      // predecessor, then moving |inst| could cause it to be executed more
      // often, so we cannot move it.
      if (cfg()->preds(bb_used_in).size() != 1) {
        break;
      }

      // If |inst| is used after the merge block, then |bb_used_in| does not
      // dominate all of the uses. So we cannot move |inst| any further.
      if (IntersectsPath(bb->MergeBlockIdIfAny(), original_bb->id(),
                         bbs_with_uses)) {
        break;
      }

      // Otherwise, |bb_used_in| dominates all uses, so move |inst| into that
      // block.
      bb = context()->get_instr_block(bb_used_in);
    }
    continue;
  }
  return (bb != original_bb ? bb : nullptr);
}

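// Returns true if |inst| loads from memory that could be modified, in which
// case moving the load could change the value it sees. The check is
// conservative: the load is considered safe only if its base pointer is an
// OpVariable that is either read-only, or is in the Uniform storage class
// with no possible store to it and no synchronization on uniform memory in
// the module. Instructions that do not read memory, such as OpAccessChain,
// trivially return false.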
bool CodeSinkingPass::ReferencesMutableMemory(Instruction* inst) {
  if (!inst->IsLoad()) {
    return false;
  }

  Instruction* base_ptr = inst->GetBaseAddress();
  if (base_ptr->opcode() != spv::Op::OpVariable) {
    return true;
  }

  if (base_ptr->IsReadOnlyPointer()) {
    return false;
  }

  if (HasUniformMemorySync()) {
    return true;
  }

  if (spv::StorageClass(base_ptr->GetSingleWordInOperand(0)) !=
      spv::StorageClass::Uniform) {
    return true;
  }

  return HasPossibleStore(base_ptr);
}

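// Returns true if the module contains a barrier or atomic instruction whose
// memory semantics apply acquire or release semantics to uniform memory. The
// memory semantics id is in-operand 1 for OpMemoryBarrier, in-operand 2 for
// OpControlBarrier and the plain atomics, and in-operands 2 and 3 (equal and
// unequal semantics) for the compare-exchange instructions. The result is
// stored in |has_uniform_sync_|.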
bool CodeSinkingPass::HasUniformMemorySync() {
  if (checked_for_uniform_sync_) {
    return has_uniform_sync_;
  }

  bool has_sync = false;
  get_module()->ForEachInst([this, &has_sync](Instruction* inst) {
    switch (inst->opcode()) {
      case spv::Op::OpMemoryBarrier: {
        uint32_t mem_semantics_id = inst->GetSingleWordInOperand(1);
        if (IsSyncOnUniform(mem_semantics_id)) {
          has_sync = true;
        }
        break;
      }
      case spv::Op::OpControlBarrier:
      case spv::Op::OpAtomicLoad:
      case spv::Op::OpAtomicStore:
      case spv::Op::OpAtomicExchange:
      case spv::Op::OpAtomicIIncrement:
      case spv::Op::OpAtomicIDecrement:
      case spv::Op::OpAtomicIAdd:
      case spv::Op::OpAtomicFAddEXT:
      case spv::Op::OpAtomicISub:
      case spv::Op::OpAtomicSMin:
      case spv::Op::OpAtomicUMin:
      case spv::Op::OpAtomicFMinEXT:
      case spv::Op::OpAtomicSMax:
      case spv::Op::OpAtomicUMax:
      case spv::Op::OpAtomicFMaxEXT:
      case spv::Op::OpAtomicAnd:
      case spv::Op::OpAtomicOr:
      case spv::Op::OpAtomicXor:
      case spv::Op::OpAtomicFlagTestAndSet:
      case spv::Op::OpAtomicFlagClear: {
        uint32_t mem_semantics_id = inst->GetSingleWordInOperand(2);
        if (IsSyncOnUniform(mem_semantics_id)) {
          has_sync = true;
        }
        break;
      }
      case spv::Op::OpAtomicCompareExchange:
      case spv::Op::OpAtomicCompareExchangeWeak:
        if (IsSyncOnUniform(inst->GetSingleWordInOperand(2)) ||
            IsSyncOnUniform(inst->GetSingleWordInOperand(3))) {
          has_sync = true;
        }
        break;
      default:
        break;
    }
  });
  // Record that the module has been scanned so the check is not repeated.
  checked_for_uniform_sync_ = true;
  has_uniform_sync_ = has_sync;
  return has_sync;
}

bool CodeSinkingPass::IsSyncOnUniform(uint32_t mem_semantics_id) const {
  const analysis::Constant* mem_semantics_const =
      context()->get_constant_mgr()->FindDeclaredConstant(mem_semantics_id);
  assert(mem_semantics_const != nullptr &&
         "Expecting memory semantics id to be a constant.");
  assert(mem_semantics_const->AsIntConstant() &&
         "Memory semantics should be an integer.");
  uint32_t mem_semantics_int = mem_semantics_const->GetU32();

  // If the semantics do not affect uniform memory, then they do not apply to
  // uniform memory.
  if ((mem_semantics_int &
       uint32_t(spv::MemorySemanticsMask::UniformMemory)) == 0) {
    return false;
  }

  // Check if there is an acquire or release. If not, this does not add any
  // memory constraints.
  return (mem_semantics_int &
          uint32_t(spv::MemorySemanticsMask::Acquire |
                   spv::MemorySemanticsMask::AcquireRelease |
                   spv::MemorySemanticsMask::Release)) != 0;
}

bool CodeSinkingPass::HasPossibleStore(Instruction* var_inst) {
  assert(var_inst->opcode() == spv::Op::OpVariable ||
         var_inst->opcode() == spv::Op::OpAccessChain ||
         var_inst->opcode() == spv::Op::OpPtrAccessChain);

  return get_def_use_mgr()->WhileEachUser(var_inst, [this](Instruction* use) {
    switch (use->opcode()) {
      case spv::Op::OpStore:
        return true;
      case spv::Op::OpAccessChain:
      case spv::Op::OpPtrAccessChain:
        return HasPossibleStore(use);
      default:
        return false;
    }
  });
}

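// Returns true if a block in |set| lies on some path that starts at |start|
// and does not pass through |end|. The search is a depth-first traversal of
// successor edges that stops at |end| and at blocks that were already
// visited.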
bool CodeSinkingPass::IntersectsPath(uint32_t start, uint32_t end,
                                     const std::unordered_set<uint32_t>& set) {
  std::vector<uint32_t> worklist;
  worklist.push_back(start);
  std::unordered_set<uint32_t> already_done;
  already_done.insert(start);

  while (!worklist.empty()) {
    BasicBlock* bb = context()->get_instr_block(worklist.back());
    worklist.pop_back();

    if (bb->id() == end) {
      continue;
    }

    if (set.count(bb->id())) {
      return true;
    }

    bb->ForEachSuccessorLabel([&already_done, &worklist](uint32_t* succ_bb_id) {
      if (already_done.insert(*succ_bb_id).second) {
        worklist.push_back(*succ_bb_id);
      }
    });
  }
  return false;
}
}  // namespace opt
}  // namespace spvtools