// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "source/opt/cfg.h"

#include <memory>
#include <utility>

#include "source/cfa.h"
#include "source/opt/ir_builder.h"
#include "source/opt/ir_context.h"
#include "source/opt/module.h"

namespace spvtools {
namespace opt {
namespace {

using cbb_ptr = const opt::BasicBlock*;

// Universal Limit of ResultID + 1
const int kMaxResultId = 0x400000;

}  // namespace

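// Builds the CFG for |module|: the pseudo entry and exit blocks are created
// up front, and every basic block of every function is registered.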
CFG::CFG(Module* module)
    : module_(module),
      pseudo_entry_block_(std::unique_ptr<Instruction>(
          new Instruction(module->context(), SpvOpLabel, 0, 0, {}))),
      pseudo_exit_block_(std::unique_ptr<Instruction>(new Instruction(
          module->context(), SpvOpLabel, 0, kMaxResultId, {}))) {
  for (auto& fn : *module) {
    for (auto& blk : fn) {
      RegisterBlock(&blk);
    }
  }
}

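// Records |blk| as a predecessor of each of its successors, and ensures |blk|
// itself has a (possibly empty) predecessor list.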
void CFG::AddEdges(BasicBlock* blk) {
  uint32_t blk_id = blk->id();
  // Force the creation of an entry; not all basic blocks have predecessors
  // (such as the entry blocks and some unreachable blocks).
  label2preds_[blk_id];
  const auto* const_blk = blk;
  const_blk->ForEachSuccessorLabel(
      [blk_id, this](const uint32_t succ_id) { AddEdge(blk_id, succ_id); });
}

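// Rebuilds the predecessor list of |blk_id|, keeping only those predecessors
// that still branch to it.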
void CFG::RemoveNonExistingEdges(uint32_t blk_id) {
  std::vector<uint32_t> updated_pred_list;
  for (uint32_t id : preds(blk_id)) {
    const BasicBlock* pred_blk = block(id);
    bool has_branch = false;
    pred_blk->ForEachSuccessorLabel([&has_branch, blk_id](uint32_t succ) {
      if (succ == blk_id) {
        has_branch = true;
      }
    });
    if (has_branch) updated_pred_list.push_back(id);
  }

  label2preds_.at(blk_id) = std::move(updated_pred_list);
}

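// Fills |order| with the blocks of |func| reachable from |root|, front to
// back, in reverse post-order over the structured successors computed by
// ComputeStructuredSuccessors.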
void CFG::ComputeStructuredOrder(Function* func, BasicBlock* root,
                                 std::list<BasicBlock*>* order) {
  assert(module_->context()->get_feature_mgr()->HasCapability(
             SpvCapabilityShader) &&
         "This only works on structured control flow");

  // Compute structured successors and do DFS.
  ComputeStructuredSuccessors(func);
  auto ignore_block = [](cbb_ptr) {};
  auto ignore_edge = [](cbb_ptr, cbb_ptr) {};
  auto get_structured_successors = [this](const BasicBlock* b) {
    return &(block2structured_succs_[b]);
  };

  // TODO(greg-lunarg): Get rid of const_cast by moving const out of the
  // cfa.h prototypes and into the invoking code.
  auto post_order = [&](cbb_ptr b) {
    order->push_front(const_cast<BasicBlock*>(b));
  };
  CFA<BasicBlock>::DepthFirstTraversal(root, get_structured_successors,
                                       ignore_block, post_order, ignore_edge);
}

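// Applies |f| to every block reachable from |bb| in post-order, skipping the
// pseudo entry and pseudo exit blocks.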
void CFG::ForEachBlockInPostOrder(BasicBlock* bb,
                                  const std::function<void(BasicBlock*)>& f) {
  std::vector<BasicBlock*> po;
  std::unordered_set<BasicBlock*> seen;
  ComputePostOrderTraversal(bb, &po, &seen);

  for (BasicBlock* current_bb : po) {
    if (!IsPseudoExitBlock(current_bb) && !IsPseudoEntryBlock(current_bb)) {
      f(current_bb);
    }
  }
}

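// Same as ForEachBlockInPostOrder, except that |f| is applied in reverse
// post-order.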
void CFG::ForEachBlockInReversePostOrder(
    BasicBlock* bb, const std::function<void(BasicBlock*)>& f) {
  std::vector<BasicBlock*> po;
  std::unordered_set<BasicBlock*> seen;
  ComputePostOrderTraversal(bb, &po, &seen);

  for (auto current_bb = po.rbegin(); current_bb != po.rend(); ++current_bb) {
    if (!IsPseudoExitBlock(*current_bb) && !IsPseudoEntryBlock(*current_bb)) {
      f(*current_bb);
    }
  }
}

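// Rebuilds block2structured_succs_ for |func|: blocks with no predecessors
// become successors of the pseudo entry block, and for headers the merge
// block (and continue block, if any) are listed before the ordinary branch
// targets.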
void CFG::ComputeStructuredSuccessors(Function* func) {
  block2structured_succs_.clear();
  for (auto& blk : *func) {
    // If the block has no predecessors in the function, make it a successor
    // of the pseudo entry block.
    if (label2preds_[blk.id()].size() == 0)
      block2structured_succs_[&pseudo_entry_block_].push_back(&blk);

    // If the block is a header, make the merge block the first successor and,
    // if there is one, the continue block the second successor.
    uint32_t mbid = blk.MergeBlockIdIfAny();
    if (mbid != 0) {
      block2structured_succs_[&blk].push_back(block(mbid));
      uint32_t cbid = blk.ContinueBlockIdIfAny();
      if (cbid != 0) {
        block2structured_succs_[&blk].push_back(block(cbid));
      }
    }

    // Add the true successors.
    const auto& const_blk = blk;
    const_blk.ForEachSuccessorLabel([&blk, this](const uint32_t sbid) {
      block2structured_succs_[&blk].push_back(block(sbid));
    });
  }
}

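// Recursive DFS helper: appends the blocks reachable from |bb| to |order| in
// post-order, using |seen| to avoid revisiting blocks on cyclic paths.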
void CFG::ComputePostOrderTraversal(BasicBlock* bb,
                                    std::vector<BasicBlock*>* order,
                                    std::unordered_set<BasicBlock*>* seen) {
  seen->insert(bb);
  static_cast<const BasicBlock*>(bb)->ForEachSuccessorLabel(
      [&order, &seen, this](const uint32_t sbid) {
        BasicBlock* succ_bb = id2block_[sbid];
        if (!seen->count(succ_bb)) {
          ComputePostOrderTraversal(succ_bb, order, seen);
        }
      });
  order->push_back(bb);
}

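// Splits the loop header |bb| in two: |bb| becomes the loop preheader (it
// keeps the OpPhi inputs that arrive from outside the loop), and a new header
// block receives the remaining instructions and the back edge from the latch.
// Returns the new header, or nullptr if no fresh result id is available.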
BasicBlock* CFG::SplitLoopHeader(BasicBlock* bb) {
  assert(bb->GetLoopMergeInst() && "Expecting bb to be the header of a loop.");

  Function* fn = bb->GetParent();
  IRContext* context = module_->context();

  // Get the new header id up front. If we are out of ids, then we cannot
  // split the loop.
  uint32_t new_header_id = context->TakeNextId();
  if (new_header_id == 0) {
    return nullptr;
  }

  // Find the insertion point for the new bb.
  Function::iterator header_it = std::find_if(
      fn->begin(), fn->end(),
      [bb](BasicBlock& block_in_func) { return &block_in_func == bb; });
  assert(header_it != fn->end());

  const std::vector<uint32_t>& pred = preds(bb->id());
  // Find the back edge.
  BasicBlock* latch_block = nullptr;
  Function::iterator latch_block_iter = header_it;
  while (++latch_block_iter != fn->end()) {
    // If blocks are in the proper order, then the only branch that appears
    // after the header is the latch.
    if (std::find(pred.begin(), pred.end(), latch_block_iter->id()) !=
        pred.end()) {
      break;
    }
  }
  assert(latch_block_iter != fn->end() && "Could not find the latch.");
  latch_block = &*latch_block_iter;

  RemoveSuccessorEdges(bb);

  // Create the new header basic block.
  // Leave the phi instructions behind.
  auto iter = bb->begin();
  while (iter->opcode() == SpvOpPhi) {
    ++iter;
  }

  BasicBlock* new_header = bb->SplitBasicBlock(context, new_header_id, iter);
  context->AnalyzeDefUse(new_header->GetLabelInst());

  // Update the CFG.
  RegisterBlock(new_header);

  // Update the instruction-to-block mappings.
  context->set_instr_block(new_header->GetLabelInst(), new_header);
  new_header->ForEachInst([new_header, context](Instruction* inst) {
    context->set_instr_block(inst, new_header);
  });

  // Adjust the OpPhi instructions: inputs that arrive on the back edge stay
  // with the new header, while inputs from outside the loop are routed
  // through the preheader.
  bb->ForEachPhiInst([latch_block, bb, new_header, context](Instruction* phi) {
    std::vector<uint32_t> preheader_phi_ops;
    std::vector<Operand> header_phi_ops;

    // Identify where the original inputs to the OpPhi belong: header or
    // preheader.
    for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) {
      uint32_t def_id = phi->GetSingleWordInOperand(i);
      uint32_t branch_id = phi->GetSingleWordInOperand(i + 1);
      if (branch_id == latch_block->id()) {
        header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {def_id}});
        header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {branch_id}});
      } else {
        preheader_phi_ops.push_back(def_id);
        preheader_phi_ops.push_back(branch_id);
      }
    }

    // Create a phi instruction in the preheader if and only if
    // preheader_phi_ops has more than one pair.
    if (preheader_phi_ops.size() > 2) {
      InstructionBuilder builder(
          context, &*bb->begin(),
          IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);

      Instruction* new_phi = builder.AddPhi(phi->type_id(), preheader_phi_ops);

      // Use the new OpPhi as the input coming from the preheader.
      header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {new_phi->result_id()}});
      header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {bb->id()}});
    } else {
      // An OpPhi with a single entry is just a copy. In this case, reuse the
      // value directly as the input coming from the preheader.
      header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {preheader_phi_ops[0]}});
      header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {bb->id()}});
    }

    phi->RemoveFromList();
    std::unique_ptr<Instruction> phi_owner(phi);
    phi->SetInOperands(std::move(header_phi_ops));
    new_header->begin()->InsertBefore(std::move(phi_owner));
    context->set_instr_block(phi, new_header);
    context->AnalyzeUses(phi);
  });

  // Add a branch to the new header.
  InstructionBuilder branch_builder(
      context, bb,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  bb->AddInstruction(
      MakeUnique<Instruction>(context, SpvOpBranch, 0, 0,
                              std::initializer_list<Operand>{
                                  {SPV_OPERAND_TYPE_ID, {new_header->id()}}}));
  context->AnalyzeUses(bb->terminator());
  context->set_instr_block(bb->terminator(), bb);
  label2preds_[new_header->id()].push_back(bb->id());

  // Update the latch to branch to the new header.
  latch_block->ForEachSuccessorLabel([bb, new_header_id](uint32_t* id) {
    if (*id == bb->id()) {
      *id = new_header_id;
    }
  });
  Instruction* latch_branch = latch_block->terminator();
  context->AnalyzeUses(latch_branch);
  label2preds_[new_header->id()].push_back(latch_block->id());

  auto& block_preds = label2preds_[bb->id()];
  auto latch_pos =
      std::find(block_preds.begin(), block_preds.end(), latch_block->id());
  assert(latch_pos != block_preds.end() && "The cfg was invalid.");
  block_preds.erase(latch_pos);

  // Update the loop descriptors.
  if (context->AreAnalysesValid(IRContext::kAnalysisLoopAnalysis)) {
    LoopDescriptor* loop_desc = context->GetLoopDescriptor(bb->GetParent());
    Loop* loop = (*loop_desc)[bb->id()];

    loop->AddBasicBlock(new_header_id);
    loop->SetHeaderBlock(new_header);
    loop_desc->SetBasicBlockToLoop(new_header_id, loop);

    loop->RemoveBasicBlock(bb->id());
    loop->SetPreHeaderBlock(bb);

    Loop* parent_loop = loop->GetParent();
    if (parent_loop != nullptr) {
      parent_loop->AddBasicBlock(bb->id());
      loop_desc->SetBasicBlockToLoop(bb->id(), parent_loop);
    } else {
      loop_desc->SetBasicBlockToLoop(bb->id(), nullptr);
    }
  }
  return new_header;
}

}  // namespace opt
}  // namespace spvtools