1 // Copyright (c) 2018 The Khronos Group Inc.
2 // Copyright (c) 2018 Valve Corporation
3 // Copyright (c) 2018 LunarG Inc.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16
17 #include "source/opt/dead_insert_elim_pass.h"
18
19 #include "source/opt/composite.h"
20 #include "source/opt/ir_context.h"
21 #include "source/opt/iterator.h"
22 #include "spirv/1.2/GLSL.std.450.h"
23
24 namespace spvtools {
25 namespace opt {
26
namespace {

// In-operand indices (result type and result id excluded) used by this pass.

// OpTypeVector: component count.
constexpr uint32_t kTypeVectorCountInIdx = 1;
// OpTypeMatrix: column count.
constexpr uint32_t kTypeMatrixCountInIdx = 1;
// OpTypeArray: <id> of the length constant.
constexpr uint32_t kTypeArrayLengthIdInIdx = 1;
// OpTypeInt: bit width.
constexpr uint32_t kTypeIntWidthInIdx = 0;
// OpConstant: first (low-order) word of the value.
constexpr uint32_t kConstantValueInIdx = 0;
// OpCompositeInsert: object being inserted.
constexpr uint32_t kInsertObjectIdInIdx = 0;
// OpCompositeInsert: composite being inserted into.
constexpr uint32_t kInsertCompositeIdInIdx = 1;

}  // anonymous namespace
38
NumComponents(Instruction * typeInst)39 uint32_t DeadInsertElimPass::NumComponents(Instruction* typeInst) {
40 switch (typeInst->opcode()) {
41 case SpvOpTypeVector: {
42 return typeInst->GetSingleWordInOperand(kTypeVectorCountInIdx);
43 } break;
44 case SpvOpTypeMatrix: {
45 return typeInst->GetSingleWordInOperand(kTypeMatrixCountInIdx);
46 } break;
47 case SpvOpTypeArray: {
48 uint32_t lenId =
49 typeInst->GetSingleWordInOperand(kTypeArrayLengthIdInIdx);
50 Instruction* lenInst = get_def_use_mgr()->GetDef(lenId);
51 if (lenInst->opcode() != SpvOpConstant) return 0;
52 uint32_t lenTypeId = lenInst->type_id();
53 Instruction* lenTypeInst = get_def_use_mgr()->GetDef(lenTypeId);
54 // TODO(greg-lunarg): Support non-32-bit array length
55 if (lenTypeInst->GetSingleWordInOperand(kTypeIntWidthInIdx) != 32)
56 return 0;
57 return lenInst->GetSingleWordInOperand(kConstantValueInIdx);
58 } break;
59 case SpvOpTypeStruct: {
60 return typeInst->NumInOperands();
61 } break;
62 default: { return 0; } break;
63 }
64 }
65
MarkInsertChain(Instruction * insertChain,std::vector<uint32_t> * pExtIndices,uint32_t extOffset,std::unordered_set<uint32_t> * visited_phis)66 void DeadInsertElimPass::MarkInsertChain(
67 Instruction* insertChain, std::vector<uint32_t>* pExtIndices,
68 uint32_t extOffset, std::unordered_set<uint32_t>* visited_phis) {
69 // Not currently optimizing array inserts.
70 Instruction* typeInst = get_def_use_mgr()->GetDef(insertChain->type_id());
71 if (typeInst->opcode() == SpvOpTypeArray) return;
72 // Insert chains are only composed of inserts and phis
73 if (insertChain->opcode() != SpvOpCompositeInsert &&
74 insertChain->opcode() != SpvOpPhi)
75 return;
76 // If extract indices are empty, mark all subcomponents if type
77 // is constant length.
78 if (pExtIndices == nullptr) {
79 uint32_t cnum = NumComponents(typeInst);
80 if (cnum > 0) {
81 std::vector<uint32_t> extIndices;
82 for (uint32_t i = 0; i < cnum; i++) {
83 extIndices.clear();
84 extIndices.push_back(i);
85 std::unordered_set<uint32_t> sub_visited_phis;
86 MarkInsertChain(insertChain, &extIndices, 0, &sub_visited_phis);
87 }
88 return;
89 }
90 }
91 Instruction* insInst = insertChain;
92 while (insInst->opcode() == SpvOpCompositeInsert) {
93 // If no extract indices, mark insert and inserted object (which might
94 // also be an insert chain) and continue up the chain though the input
95 // composite.
96 //
97 // Note: We mark inserted objects in this function (rather than in
98 // EliminateDeadInsertsOnePass) because in some cases, we can do it
99 // more accurately here.
100 if (pExtIndices == nullptr) {
101 liveInserts_.insert(insInst->result_id());
102 uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx);
103 std::unordered_set<uint32_t> obj_visited_phis;
104 MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0,
105 &obj_visited_phis);
106 // If extract indices match insert, we are done. Mark insert and
107 // inserted object.
108 } else if (ExtInsMatch(*pExtIndices, insInst, extOffset)) {
109 liveInserts_.insert(insInst->result_id());
110 uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx);
111 std::unordered_set<uint32_t> obj_visited_phis;
112 MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0,
113 &obj_visited_phis);
114 break;
115 // If non-matching intersection, mark insert
116 } else if (ExtInsConflict(*pExtIndices, insInst, extOffset)) {
117 liveInserts_.insert(insInst->result_id());
118 // If more extract indices than insert, we are done. Use remaining
119 // extract indices to mark inserted object.
120 uint32_t numInsertIndices = insInst->NumInOperands() - 2;
121 if (pExtIndices->size() - extOffset > numInsertIndices) {
122 uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx);
123 std::unordered_set<uint32_t> obj_visited_phis;
124 MarkInsertChain(get_def_use_mgr()->GetDef(objId), pExtIndices,
125 extOffset + numInsertIndices, &obj_visited_phis);
126 break;
127 // If fewer extract indices than insert, also mark inserted object and
128 // continue up chain.
129 } else {
130 uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx);
131 std::unordered_set<uint32_t> obj_visited_phis;
132 MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0,
133 &obj_visited_phis);
134 }
135 }
136 // Get next insert in chain
137 const uint32_t compId =
138 insInst->GetSingleWordInOperand(kInsertCompositeIdInIdx);
139 insInst = get_def_use_mgr()->GetDef(compId);
140 }
141 // If insert chain ended with phi, do recursive call on each operand
142 if (insInst->opcode() != SpvOpPhi) return;
143 // Mark phi visited to prevent potential infinite loop. If phi is already
144 // visited, return to avoid infinite loop.
145 if (visited_phis->count(insInst->result_id()) != 0) return;
146 visited_phis->insert(insInst->result_id());
147
148 // Phis may have duplicate inputs values for different edges, prune incoming
149 // ids lists before recursing.
150 std::vector<uint32_t> ids;
151 for (uint32_t i = 0; i < insInst->NumInOperands(); i += 2) {
152 ids.push_back(insInst->GetSingleWordInOperand(i));
153 }
154 std::sort(ids.begin(), ids.end());
155 auto new_end = std::unique(ids.begin(), ids.end());
156 for (auto id_iter = ids.begin(); id_iter != new_end; ++id_iter) {
157 Instruction* pi = get_def_use_mgr()->GetDef(*id_iter);
158 MarkInsertChain(pi, pExtIndices, extOffset, visited_phis);
159 }
160 }
161
EliminateDeadInserts(Function * func)162 bool DeadInsertElimPass::EliminateDeadInserts(Function* func) {
163 bool modified = false;
164 bool lastmodified = true;
165 // Each pass can delete dead instructions, thus potentially revealing
166 // new dead insertions ie insertions with no uses.
167 while (lastmodified) {
168 lastmodified = EliminateDeadInsertsOnePass(func);
169 modified |= lastmodified;
170 }
171 return modified;
172 }
173
EliminateDeadInsertsOnePass(Function * func)174 bool DeadInsertElimPass::EliminateDeadInsertsOnePass(Function* func) {
175 bool modified = false;
176 liveInserts_.clear();
177 visitedPhis_.clear();
178 // Mark all live inserts
179 for (auto bi = func->begin(); bi != func->end(); ++bi) {
180 for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
181 // Only process Inserts and composite Phis
182 SpvOp op = ii->opcode();
183 Instruction* typeInst = get_def_use_mgr()->GetDef(ii->type_id());
184 if (op != SpvOpCompositeInsert &&
185 (op != SpvOpPhi || !spvOpcodeIsComposite(typeInst->opcode())))
186 continue;
187 // The marking algorithm can be expensive for large arrays and the
188 // efficacy of eliminating dead inserts into arrays is questionable.
189 // Skip optimizing array inserts for now. Just mark them live.
190 // TODO(greg-lunarg): Eliminate dead array inserts
191 if (op == SpvOpCompositeInsert) {
192 if (typeInst->opcode() == SpvOpTypeArray) {
193 liveInserts_.insert(ii->result_id());
194 continue;
195 }
196 }
197 const uint32_t id = ii->result_id();
198 get_def_use_mgr()->ForEachUser(id, [&ii, this](Instruction* user) {
199 if (user->IsCommonDebugInstr()) return;
200 switch (user->opcode()) {
201 case SpvOpCompositeInsert:
202 case SpvOpPhi:
203 // Use by insert or phi does not initiate marking
204 break;
205 case SpvOpCompositeExtract: {
206 // Capture extract indices
207 std::vector<uint32_t> extIndices;
208 uint32_t icnt = 0;
209 user->ForEachInOperand([&icnt, &extIndices](const uint32_t* idp) {
210 if (icnt > 0) extIndices.push_back(*idp);
211 ++icnt;
212 });
213 // Mark all inserts in chain that intersect with extract
214 std::unordered_set<uint32_t> visited_phis;
215 MarkInsertChain(&*ii, &extIndices, 0, &visited_phis);
216 } break;
217 default: {
218 // Mark inserts in chain for all components
219 MarkInsertChain(&*ii, nullptr, 0, nullptr);
220 } break;
221 }
222 });
223 }
224 }
225 // Find and disconnect dead inserts
226 std::vector<Instruction*> dead_instructions;
227 for (auto bi = func->begin(); bi != func->end(); ++bi) {
228 for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
229 if (ii->opcode() != SpvOpCompositeInsert) continue;
230 const uint32_t id = ii->result_id();
231 if (liveInserts_.find(id) != liveInserts_.end()) continue;
232 const uint32_t replId =
233 ii->GetSingleWordInOperand(kInsertCompositeIdInIdx);
234 (void)context()->ReplaceAllUsesWith(id, replId);
235 dead_instructions.push_back(&*ii);
236 modified = true;
237 }
238 }
239 // DCE dead inserts
240 while (!dead_instructions.empty()) {
241 Instruction* inst = dead_instructions.back();
242 dead_instructions.pop_back();
243 DCEInst(inst, [&dead_instructions](Instruction* other_inst) {
244 auto i = std::find(dead_instructions.begin(), dead_instructions.end(),
245 other_inst);
246 if (i != dead_instructions.end()) {
247 dead_instructions.erase(i);
248 }
249 });
250 }
251 return modified;
252 }
253
Process()254 Pass::Status DeadInsertElimPass::Process() {
255 // Process all entry point functions.
256 ProcessFunction pfn = [this](Function* fp) {
257 return EliminateDeadInserts(fp);
258 };
259 bool modified = context()->ProcessReachableCallTree(pfn);
260 return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
261 }
262
263 } // namespace opt
264 } // namespace spvtools
265