1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/cfg.h"
16
17 #include <memory>
18 #include <utility>
19
20 #include "source/cfa.h"
21 #include "source/opt/ir_builder.h"
22 #include "source/opt/ir_context.h"
23 #include "source/opt/module.h"
24
25 namespace spvtools {
26 namespace opt {
27 namespace {
28
29 using cbb_ptr = const opt::BasicBlock*;
30
31 // Universal Limit of ResultID + 1
32 constexpr int kMaxResultId = 0x400000;
33
34 } // namespace
35
CFG(Module * module)36 CFG::CFG(Module* module)
37 : module_(module),
38 pseudo_entry_block_(std::unique_ptr<Instruction>(
39 new Instruction(module->context(), spv::Op::OpLabel, 0, 0, {}))),
40 pseudo_exit_block_(std::unique_ptr<Instruction>(new Instruction(
41 module->context(), spv::Op::OpLabel, 0, kMaxResultId, {}))) {
42 for (auto& fn : *module) {
43 for (auto& blk : fn) {
44 RegisterBlock(&blk);
45 }
46 }
47 }
48
AddEdges(BasicBlock * blk)49 void CFG::AddEdges(BasicBlock* blk) {
50 uint32_t blk_id = blk->id();
51 // Force the creation of an entry, not all basic block have predecessors
52 // (such as the entry blocks and some unreachables).
53 label2preds_[blk_id];
54 const auto* const_blk = blk;
55 const_blk->ForEachSuccessorLabel(
56 [blk_id, this](const uint32_t succ_id) { AddEdge(blk_id, succ_id); });
57 }
58
RemoveNonExistingEdges(uint32_t blk_id)59 void CFG::RemoveNonExistingEdges(uint32_t blk_id) {
60 std::vector<uint32_t> updated_pred_list;
61 for (uint32_t id : preds(blk_id)) {
62 const BasicBlock* pred_blk = block(id);
63 bool has_branch = false;
64 pred_blk->ForEachSuccessorLabel([&has_branch, blk_id](uint32_t succ) {
65 if (succ == blk_id) {
66 has_branch = true;
67 }
68 });
69 if (has_branch) updated_pred_list.push_back(id);
70 }
71
72 label2preds_.at(blk_id) = std::move(updated_pred_list);
73 }
74
ComputeStructuredOrder(Function * func,BasicBlock * root,std::list<BasicBlock * > * order)75 void CFG::ComputeStructuredOrder(Function* func, BasicBlock* root,
76 std::list<BasicBlock*>* order) {
77 ComputeStructuredOrder(func, root, nullptr, order);
78 }
79
ComputeStructuredOrder(Function * func,BasicBlock * root,BasicBlock * end,std::list<BasicBlock * > * order)80 void CFG::ComputeStructuredOrder(Function* func, BasicBlock* root,
81 BasicBlock* end,
82 std::list<BasicBlock*>* order) {
83 assert(module_->context()->get_feature_mgr()->HasCapability(
84 spv::Capability::Shader) &&
85 "This only works on structured control flow");
86
87 // Compute structured successors and do DFS.
88 ComputeStructuredSuccessors(func);
89 auto ignore_block = [](cbb_ptr) {};
90 auto terminal = [end](cbb_ptr bb) { return bb == end; };
91
92 auto get_structured_successors = [this](const BasicBlock* b) {
93 return &(block2structured_succs_[b]);
94 };
95
96 // TODO(greg-lunarg): Get rid of const_cast by making moving const
97 // out of the cfa.h prototypes and into the invoking code.
98 auto post_order = [&](cbb_ptr b) {
99 order->push_front(const_cast<BasicBlock*>(b));
100 };
101 CFA<BasicBlock>::DepthFirstTraversal(root, get_structured_successors,
102 ignore_block, post_order, terminal);
103 }
104
ForEachBlockInPostOrder(BasicBlock * bb,const std::function<void (BasicBlock *)> & f)105 void CFG::ForEachBlockInPostOrder(BasicBlock* bb,
106 const std::function<void(BasicBlock*)>& f) {
107 std::vector<BasicBlock*> po;
108 std::unordered_set<BasicBlock*> seen;
109 ComputePostOrderTraversal(bb, &po, &seen);
110
111 for (BasicBlock* current_bb : po) {
112 if (!IsPseudoExitBlock(current_bb) && !IsPseudoEntryBlock(current_bb)) {
113 f(current_bb);
114 }
115 }
116 }
117
ForEachBlockInReversePostOrder(BasicBlock * bb,const std::function<void (BasicBlock *)> & f)118 void CFG::ForEachBlockInReversePostOrder(
119 BasicBlock* bb, const std::function<void(BasicBlock*)>& f) {
120 WhileEachBlockInReversePostOrder(bb, [f](BasicBlock* b) {
121 f(b);
122 return true;
123 });
124 }
125
WhileEachBlockInReversePostOrder(BasicBlock * bb,const std::function<bool (BasicBlock *)> & f)126 bool CFG::WhileEachBlockInReversePostOrder(
127 BasicBlock* bb, const std::function<bool(BasicBlock*)>& f) {
128 std::vector<BasicBlock*> po;
129 std::unordered_set<BasicBlock*> seen;
130 ComputePostOrderTraversal(bb, &po, &seen);
131
132 for (auto current_bb = po.rbegin(); current_bb != po.rend(); ++current_bb) {
133 if (!IsPseudoExitBlock(*current_bb) && !IsPseudoEntryBlock(*current_bb)) {
134 if (!f(*current_bb)) {
135 return false;
136 }
137 }
138 }
139 return true;
140 }
141
ComputeStructuredSuccessors(Function * func)142 void CFG::ComputeStructuredSuccessors(Function* func) {
143 block2structured_succs_.clear();
144 for (auto& blk : *func) {
145 // If no predecessors in function, make successor to pseudo entry.
146 if (label2preds_[blk.id()].size() == 0)
147 block2structured_succs_[&pseudo_entry_block_].push_back(&blk);
148
149 // If header, make merge block first successor and continue block second
150 // successor if there is one.
151 uint32_t mbid = blk.MergeBlockIdIfAny();
152 if (mbid != 0) {
153 block2structured_succs_[&blk].push_back(block(mbid));
154 uint32_t cbid = blk.ContinueBlockIdIfAny();
155 if (cbid != 0) {
156 block2structured_succs_[&blk].push_back(block(cbid));
157 }
158 }
159
160 // Add true successors.
161 const auto& const_blk = blk;
162 const_blk.ForEachSuccessorLabel([&blk, this](const uint32_t sbid) {
163 block2structured_succs_[&blk].push_back(block(sbid));
164 });
165 }
166 }
167
ComputePostOrderTraversal(BasicBlock * bb,std::vector<BasicBlock * > * order,std::unordered_set<BasicBlock * > * seen)168 void CFG::ComputePostOrderTraversal(BasicBlock* bb,
169 std::vector<BasicBlock*>* order,
170 std::unordered_set<BasicBlock*>* seen) {
171 std::vector<BasicBlock*> stack;
172 stack.push_back(bb);
173 while (!stack.empty()) {
174 bb = stack.back();
175 seen->insert(bb);
176 static_cast<const BasicBlock*>(bb)->WhileEachSuccessorLabel(
177 [&seen, &stack, this](const uint32_t sbid) {
178 BasicBlock* succ_bb = id2block_[sbid];
179 if (!seen->count(succ_bb)) {
180 stack.push_back(succ_bb);
181 return false;
182 }
183 return true;
184 });
185 if (stack.back() == bb) {
186 order->push_back(bb);
187 stack.pop_back();
188 }
189 }
190 }
191
SplitLoopHeader(BasicBlock * bb)192 BasicBlock* CFG::SplitLoopHeader(BasicBlock* bb) {
193 assert(bb->GetLoopMergeInst() && "Expecting bb to be the header of a loop.");
194
195 Function* fn = bb->GetParent();
196 IRContext* context = module_->context();
197
198 // Get the new header id up front. If we are out of ids, then we cannot split
199 // the loop.
200 uint32_t new_header_id = context->TakeNextId();
201 if (new_header_id == 0) {
202 return nullptr;
203 }
204
205 // Find the insertion point for the new bb.
206 Function::iterator header_it = std::find_if(
207 fn->begin(), fn->end(),
208 [bb](BasicBlock& block_in_func) { return &block_in_func == bb; });
209 assert(header_it != fn->end());
210
211 const std::vector<uint32_t>& pred = preds(bb->id());
212 // Find the back edge
213 BasicBlock* latch_block = nullptr;
214 Function::iterator latch_block_iter = header_it;
215 for (; latch_block_iter != fn->end(); ++latch_block_iter) {
216 // If blocks are in the proper order, then the only branch that appears
217 // after the header is the latch.
218 if (std::find(pred.begin(), pred.end(), latch_block_iter->id()) !=
219 pred.end()) {
220 break;
221 }
222 }
223 assert(latch_block_iter != fn->end() && "Could not find the latch.");
224 latch_block = &*latch_block_iter;
225
226 RemoveSuccessorEdges(bb);
227
228 // Create the new header bb basic bb.
229 // Leave the phi instructions behind.
230 auto iter = bb->begin();
231 while (iter->opcode() == spv::Op::OpPhi) {
232 ++iter;
233 }
234
235 BasicBlock* new_header = bb->SplitBasicBlock(context, new_header_id, iter);
236 context->AnalyzeDefUse(new_header->GetLabelInst());
237
238 // Update cfg
239 RegisterBlock(new_header);
240
241 // Update bb mappings.
242 context->set_instr_block(new_header->GetLabelInst(), new_header);
243 new_header->ForEachInst([new_header, context](Instruction* inst) {
244 context->set_instr_block(inst, new_header);
245 });
246
247 // If |bb| was the latch block, the branch back to the header is not in
248 // |new_header|.
249 if (latch_block == bb) {
250 if (new_header->ContinueBlockId() == bb->id()) {
251 new_header->GetLoopMergeInst()->SetInOperand(1, {new_header_id});
252 }
253 latch_block = new_header;
254 }
255
256 // Adjust the OpPhi instructions as needed.
257 bb->ForEachPhiInst([latch_block, bb, new_header, context](Instruction* phi) {
258 std::vector<uint32_t> preheader_phi_ops;
259 std::vector<Operand> header_phi_ops;
260
261 // Identify where the original inputs to original OpPhi belong: header or
262 // preheader.
263 for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) {
264 uint32_t def_id = phi->GetSingleWordInOperand(i);
265 uint32_t branch_id = phi->GetSingleWordInOperand(i + 1);
266 if (branch_id == latch_block->id()) {
267 header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {def_id}});
268 header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {branch_id}});
269 } else {
270 preheader_phi_ops.push_back(def_id);
271 preheader_phi_ops.push_back(branch_id);
272 }
273 }
274
275 // Create a phi instruction if and only if the preheader_phi_ops has more
276 // than one pair.
277 if (preheader_phi_ops.size() > 2) {
278 InstructionBuilder builder(
279 context, &*bb->begin(),
280 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
281
282 Instruction* new_phi = builder.AddPhi(phi->type_id(), preheader_phi_ops);
283
284 // Add the OpPhi to the header bb.
285 header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {new_phi->result_id()}});
286 header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {bb->id()}});
287 } else {
288 // An OpPhi with a single entry is just a copy. In this case use the same
289 // instruction in the new header.
290 header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {preheader_phi_ops[0]}});
291 header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {bb->id()}});
292 }
293
294 phi->RemoveFromList();
295 std::unique_ptr<Instruction> phi_owner(phi);
296 phi->SetInOperands(std::move(header_phi_ops));
297 new_header->begin()->InsertBefore(std::move(phi_owner));
298 context->set_instr_block(phi, new_header);
299 context->AnalyzeUses(phi);
300 });
301
302 // Add a branch to the new header.
303 InstructionBuilder branch_builder(
304 context, bb,
305 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
306 bb->AddInstruction(
307 MakeUnique<Instruction>(context, spv::Op::OpBranch, 0, 0,
308 std::initializer_list<Operand>{
309 {SPV_OPERAND_TYPE_ID, {new_header->id()}}}));
310 context->AnalyzeUses(bb->terminator());
311 context->set_instr_block(bb->terminator(), bb);
312 label2preds_[new_header->id()].push_back(bb->id());
313
314 // Update the latch to branch to the new header.
315 latch_block->ForEachSuccessorLabel([bb, new_header_id](uint32_t* id) {
316 if (*id == bb->id()) {
317 *id = new_header_id;
318 }
319 });
320 Instruction* latch_branch = latch_block->terminator();
321 context->AnalyzeUses(latch_branch);
322 label2preds_[new_header->id()].push_back(latch_block->id());
323
324 auto& block_preds = label2preds_[bb->id()];
325 auto latch_pos =
326 std::find(block_preds.begin(), block_preds.end(), latch_block->id());
327 assert(latch_pos != block_preds.end() && "The cfg was invalid.");
328 block_preds.erase(latch_pos);
329
330 // Update the loop descriptors
331 if (context->AreAnalysesValid(IRContext::kAnalysisLoopAnalysis)) {
332 LoopDescriptor* loop_desc = context->GetLoopDescriptor(bb->GetParent());
333 Loop* loop = (*loop_desc)[bb->id()];
334
335 loop->AddBasicBlock(new_header_id);
336 loop->SetHeaderBlock(new_header);
337 loop_desc->SetBasicBlockToLoop(new_header_id, loop);
338
339 loop->RemoveBasicBlock(bb->id());
340 loop->SetPreHeaderBlock(bb);
341
342 Loop* parent_loop = loop->GetParent();
343 if (parent_loop != nullptr) {
344 parent_loop->AddBasicBlock(bb->id());
345 loop_desc->SetBasicBlockToLoop(bb->id(), parent_loop);
346 } else {
347 loop_desc->SetBasicBlockToLoop(bb->id(), nullptr);
348 }
349 }
350 return new_header;
351 }
352
353 } // namespace opt
354 } // namespace spvtools
355