1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/licm_pass.h"
16
17 #include <queue>
18
19 #include "source/opt/module.h"
20 #include "source/opt/pass.h"
21
22 namespace spvtools {
23 namespace opt {
24
Process()25 Pass::Status LICMPass::Process() { return ProcessIRContext(); }
26
ProcessIRContext()27 Pass::Status LICMPass::ProcessIRContext() {
28 Status status = Status::SuccessWithoutChange;
29 Module* module = get_module();
30
31 // Process each function in the module
32 for (auto func = module->begin();
33 func != module->end() && status != Status::Failure; ++func) {
34 status = CombineStatus(status, ProcessFunction(&*func));
35 }
36 return status;
37 }
38
ProcessFunction(Function * f)39 Pass::Status LICMPass::ProcessFunction(Function* f) {
40 Status status = Status::SuccessWithoutChange;
41 LoopDescriptor* loop_descriptor = context()->GetLoopDescriptor(f);
42
43 // Process each loop in the function
44 for (auto it = loop_descriptor->begin();
45 it != loop_descriptor->end() && status != Status::Failure; ++it) {
46 Loop& loop = *it;
47 // Ignore nested loops, as we will process them in order in ProcessLoop
48 if (loop.IsNested()) {
49 continue;
50 }
51 status = CombineStatus(status, ProcessLoop(&loop, f));
52 }
53 return status;
54 }
55
ProcessLoop(Loop * loop,Function * f)56 Pass::Status LICMPass::ProcessLoop(Loop* loop, Function* f) {
57 Status status = Status::SuccessWithoutChange;
58
59 // Process all nested loops first
60 for (auto nl = loop->begin(); nl != loop->end() && status != Status::Failure;
61 ++nl) {
62 Loop* nested_loop = *nl;
63 status = CombineStatus(status, ProcessLoop(nested_loop, f));
64 }
65
66 std::vector<BasicBlock*> loop_bbs{};
67 status = CombineStatus(
68 status,
69 AnalyseAndHoistFromBB(loop, f, loop->GetHeaderBlock(), &loop_bbs));
70
71 for (size_t i = 0; i < loop_bbs.size() && status != Status::Failure; ++i) {
72 BasicBlock* bb = loop_bbs[i];
73 // do not delete the element
74 status =
75 CombineStatus(status, AnalyseAndHoistFromBB(loop, f, bb, &loop_bbs));
76 }
77
78 return status;
79 }
80
AnalyseAndHoistFromBB(Loop * loop,Function * f,BasicBlock * bb,std::vector<BasicBlock * > * loop_bbs)81 Pass::Status LICMPass::AnalyseAndHoistFromBB(
82 Loop* loop, Function* f, BasicBlock* bb,
83 std::vector<BasicBlock*>* loop_bbs) {
84 bool modified = false;
85 std::function<bool(Instruction*)> hoist_inst =
86 [this, &loop, &modified](Instruction* inst) {
87 if (loop->ShouldHoistInstruction(*inst)) {
88 if (!HoistInstruction(loop, inst)) {
89 return false;
90 }
91 modified = true;
92 }
93 return true;
94 };
95
96 if (IsImmediatelyContainedInLoop(loop, f, bb)) {
97 if (!bb->WhileEachInst(hoist_inst, false)) {
98 return Status::Failure;
99 }
100 }
101
102 DominatorAnalysis* dom_analysis = context()->GetDominatorAnalysis(f);
103 DominatorTree& dom_tree = dom_analysis->GetDomTree();
104
105 for (DominatorTreeNode* child_dom_tree_node : *dom_tree.GetTreeNode(bb)) {
106 if (loop->IsInsideLoop(child_dom_tree_node->bb_)) {
107 loop_bbs->push_back(child_dom_tree_node->bb_);
108 }
109 }
110
111 return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
112 }
113
IsImmediatelyContainedInLoop(Loop * loop,Function * f,BasicBlock * bb)114 bool LICMPass::IsImmediatelyContainedInLoop(Loop* loop, Function* f,
115 BasicBlock* bb) {
116 LoopDescriptor* loop_descriptor = context()->GetLoopDescriptor(f);
117 return loop == (*loop_descriptor)[bb->id()];
118 }
119
HoistInstruction(Loop * loop,Instruction * inst)120 bool LICMPass::HoistInstruction(Loop* loop, Instruction* inst) {
121 // TODO(1841): Handle failure to create pre-header.
122 BasicBlock* pre_header_bb = loop->GetOrCreatePreHeaderBlock();
123 if (!pre_header_bb) {
124 return false;
125 }
126 Instruction* insertion_point = &*pre_header_bb->tail();
127 Instruction* previous_node = insertion_point->PreviousNode();
128 if (previous_node && (previous_node->opcode() == spv::Op::OpLoopMerge ||
129 previous_node->opcode() == spv::Op::OpSelectionMerge)) {
130 insertion_point = previous_node;
131 }
132
133 inst->InsertBefore(insertion_point);
134 context()->set_instr_block(inst, pre_header_bb);
135 return true;
136 }
137
138 } // namespace opt
139 } // namespace spvtools
140