1 // Copyright (c) 2018 The Khronos Group Inc.
2 // Copyright (c) 2018 Valve Corporation
3 // Copyright (c) 2018 LunarG Inc.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16
17 #include "source/opt/dead_insert_elim_pass.h"
18
19 #include "source/opt/composite.h"
20 #include "source/opt/ir_context.h"
21 #include "source/opt/iterator.h"
22 #include "spirv/1.2/GLSL.std.450.h"
23
24 namespace spvtools {
25 namespace opt {
namespace {
// In-operand positions read by this pass.

// Component count of an OpTypeVector.
constexpr uint32_t kTypeVectorCountInIdx{1};
// Count operand of an OpTypeMatrix.
constexpr uint32_t kTypeMatrixCountInIdx{1};
// Id of the length constant of an OpTypeArray.
constexpr uint32_t kTypeArrayLengthIdInIdx{1};
// Bit width of the integer type of an array length constant.
constexpr uint32_t kTypeIntWidthInIdx{0};
// Value word of an OpConstant.
constexpr uint32_t kConstantValueInIdx{0};
// Id of the object inserted by an OpCompositeInsert.
constexpr uint32_t kInsertObjectIdInIdx{0};
// Id of the composite an OpCompositeInsert inserts into.
constexpr uint32_t kInsertCompositeIdInIdx{1};
}  // namespace
35
NumComponents(Instruction * typeInst)36 uint32_t DeadInsertElimPass::NumComponents(Instruction* typeInst) {
37 switch (typeInst->opcode()) {
38 case spv::Op::OpTypeVector: {
39 return typeInst->GetSingleWordInOperand(kTypeVectorCountInIdx);
40 } break;
41 case spv::Op::OpTypeMatrix: {
42 return typeInst->GetSingleWordInOperand(kTypeMatrixCountInIdx);
43 } break;
44 case spv::Op::OpTypeArray: {
45 uint32_t lenId =
46 typeInst->GetSingleWordInOperand(kTypeArrayLengthIdInIdx);
47 Instruction* lenInst = get_def_use_mgr()->GetDef(lenId);
48 if (lenInst->opcode() != spv::Op::OpConstant) return 0;
49 uint32_t lenTypeId = lenInst->type_id();
50 Instruction* lenTypeInst = get_def_use_mgr()->GetDef(lenTypeId);
51 // TODO(greg-lunarg): Support non-32-bit array length
52 if (lenTypeInst->GetSingleWordInOperand(kTypeIntWidthInIdx) != 32)
53 return 0;
54 return lenInst->GetSingleWordInOperand(kConstantValueInIdx);
55 } break;
56 case spv::Op::OpTypeStruct: {
57 return typeInst->NumInOperands();
58 } break;
59 default: { return 0; } break;
60 }
61 }
62
MarkInsertChain(Instruction * insertChain,std::vector<uint32_t> * pExtIndices,uint32_t extOffset,std::unordered_set<uint32_t> * visited_phis)63 void DeadInsertElimPass::MarkInsertChain(
64 Instruction* insertChain, std::vector<uint32_t>* pExtIndices,
65 uint32_t extOffset, std::unordered_set<uint32_t>* visited_phis) {
66 // Not currently optimizing array inserts.
67 Instruction* typeInst = get_def_use_mgr()->GetDef(insertChain->type_id());
68 if (typeInst->opcode() == spv::Op::OpTypeArray) return;
69 // Insert chains are only composed of inserts and phis
70 if (insertChain->opcode() != spv::Op::OpCompositeInsert &&
71 insertChain->opcode() != spv::Op::OpPhi)
72 return;
73 // If extract indices are empty, mark all subcomponents if type
74 // is constant length.
75 if (pExtIndices == nullptr) {
76 uint32_t cnum = NumComponents(typeInst);
77 if (cnum > 0) {
78 std::vector<uint32_t> extIndices;
79 for (uint32_t i = 0; i < cnum; i++) {
80 extIndices.clear();
81 extIndices.push_back(i);
82 std::unordered_set<uint32_t> sub_visited_phis;
83 MarkInsertChain(insertChain, &extIndices, 0, &sub_visited_phis);
84 }
85 return;
86 }
87 }
88 Instruction* insInst = insertChain;
89 while (insInst->opcode() == spv::Op::OpCompositeInsert) {
90 // If no extract indices, mark insert and inserted object (which might
91 // also be an insert chain) and continue up the chain though the input
92 // composite.
93 //
94 // Note: We mark inserted objects in this function (rather than in
95 // EliminateDeadInsertsOnePass) because in some cases, we can do it
96 // more accurately here.
97 if (pExtIndices == nullptr) {
98 liveInserts_.insert(insInst->result_id());
99 uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx);
100 std::unordered_set<uint32_t> obj_visited_phis;
101 MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0,
102 &obj_visited_phis);
103 // If extract indices match insert, we are done. Mark insert and
104 // inserted object.
105 } else if (ExtInsMatch(*pExtIndices, insInst, extOffset)) {
106 liveInserts_.insert(insInst->result_id());
107 uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx);
108 std::unordered_set<uint32_t> obj_visited_phis;
109 MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0,
110 &obj_visited_phis);
111 break;
112 // If non-matching intersection, mark insert
113 } else if (ExtInsConflict(*pExtIndices, insInst, extOffset)) {
114 liveInserts_.insert(insInst->result_id());
115 // If more extract indices than insert, we are done. Use remaining
116 // extract indices to mark inserted object.
117 uint32_t numInsertIndices = insInst->NumInOperands() - 2;
118 if (pExtIndices->size() - extOffset > numInsertIndices) {
119 uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx);
120 std::unordered_set<uint32_t> obj_visited_phis;
121 MarkInsertChain(get_def_use_mgr()->GetDef(objId), pExtIndices,
122 extOffset + numInsertIndices, &obj_visited_phis);
123 break;
124 // If fewer extract indices than insert, also mark inserted object and
125 // continue up chain.
126 } else {
127 uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx);
128 std::unordered_set<uint32_t> obj_visited_phis;
129 MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0,
130 &obj_visited_phis);
131 }
132 }
133 // Get next insert in chain
134 const uint32_t compId =
135 insInst->GetSingleWordInOperand(kInsertCompositeIdInIdx);
136 insInst = get_def_use_mgr()->GetDef(compId);
137 }
138 // If insert chain ended with phi, do recursive call on each operand
139 if (insInst->opcode() != spv::Op::OpPhi) return;
140 // Mark phi visited to prevent potential infinite loop. If phi is already
141 // visited, return to avoid infinite loop.
142 if (visited_phis->count(insInst->result_id()) != 0) return;
143 visited_phis->insert(insInst->result_id());
144
145 // Phis may have duplicate inputs values for different edges, prune incoming
146 // ids lists before recursing.
147 std::vector<uint32_t> ids;
148 for (uint32_t i = 0; i < insInst->NumInOperands(); i += 2) {
149 ids.push_back(insInst->GetSingleWordInOperand(i));
150 }
151 std::sort(ids.begin(), ids.end());
152 auto new_end = std::unique(ids.begin(), ids.end());
153 for (auto id_iter = ids.begin(); id_iter != new_end; ++id_iter) {
154 Instruction* pi = get_def_use_mgr()->GetDef(*id_iter);
155 MarkInsertChain(pi, pExtIndices, extOffset, visited_phis);
156 }
157 }
158
EliminateDeadInserts(Function * func)159 bool DeadInsertElimPass::EliminateDeadInserts(Function* func) {
160 bool modified = false;
161 bool lastmodified = true;
162 // Each pass can delete dead instructions, thus potentially revealing
163 // new dead insertions ie insertions with no uses.
164 while (lastmodified) {
165 lastmodified = EliminateDeadInsertsOnePass(func);
166 modified |= lastmodified;
167 }
168 return modified;
169 }
170
EliminateDeadInsertsOnePass(Function * func)171 bool DeadInsertElimPass::EliminateDeadInsertsOnePass(Function* func) {
172 bool modified = false;
173 liveInserts_.clear();
174 visitedPhis_.clear();
175 // Mark all live inserts
176 for (auto bi = func->begin(); bi != func->end(); ++bi) {
177 for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
178 // Only process Inserts and composite Phis
179 spv::Op op = ii->opcode();
180 Instruction* typeInst = get_def_use_mgr()->GetDef(ii->type_id());
181 if (op != spv::Op::OpCompositeInsert &&
182 (op != spv::Op::OpPhi || !spvOpcodeIsComposite(typeInst->opcode())))
183 continue;
184 // The marking algorithm can be expensive for large arrays and the
185 // efficacy of eliminating dead inserts into arrays is questionable.
186 // Skip optimizing array inserts for now. Just mark them live.
187 // TODO(greg-lunarg): Eliminate dead array inserts
188 if (op == spv::Op::OpCompositeInsert) {
189 if (typeInst->opcode() == spv::Op::OpTypeArray) {
190 liveInserts_.insert(ii->result_id());
191 continue;
192 }
193 }
194 const uint32_t id = ii->result_id();
195 get_def_use_mgr()->ForEachUser(id, [&ii, this](Instruction* user) {
196 if (user->IsCommonDebugInstr()) return;
197 switch (user->opcode()) {
198 case spv::Op::OpCompositeInsert:
199 case spv::Op::OpPhi:
200 // Use by insert or phi does not initiate marking
201 break;
202 case spv::Op::OpCompositeExtract: {
203 // Capture extract indices
204 std::vector<uint32_t> extIndices;
205 uint32_t icnt = 0;
206 user->ForEachInOperand([&icnt, &extIndices](const uint32_t* idp) {
207 if (icnt > 0) extIndices.push_back(*idp);
208 ++icnt;
209 });
210 // Mark all inserts in chain that intersect with extract
211 std::unordered_set<uint32_t> visited_phis;
212 MarkInsertChain(&*ii, &extIndices, 0, &visited_phis);
213 } break;
214 default: {
215 // Mark inserts in chain for all components
216 MarkInsertChain(&*ii, nullptr, 0, nullptr);
217 } break;
218 }
219 });
220 }
221 }
222 // Find and disconnect dead inserts
223 std::vector<Instruction*> dead_instructions;
224 for (auto bi = func->begin(); bi != func->end(); ++bi) {
225 for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
226 if (ii->opcode() != spv::Op::OpCompositeInsert) continue;
227 const uint32_t id = ii->result_id();
228 if (liveInserts_.find(id) != liveInserts_.end()) continue;
229 const uint32_t replId =
230 ii->GetSingleWordInOperand(kInsertCompositeIdInIdx);
231 (void)context()->ReplaceAllUsesWith(id, replId);
232 dead_instructions.push_back(&*ii);
233 modified = true;
234 }
235 }
236 // DCE dead inserts
237 while (!dead_instructions.empty()) {
238 Instruction* inst = dead_instructions.back();
239 dead_instructions.pop_back();
240 DCEInst(inst, [&dead_instructions](Instruction* other_inst) {
241 auto i = std::find(dead_instructions.begin(), dead_instructions.end(),
242 other_inst);
243 if (i != dead_instructions.end()) {
244 dead_instructions.erase(i);
245 }
246 });
247 }
248 return modified;
249 }
250
Process()251 Pass::Status DeadInsertElimPass::Process() {
252 // Process all entry point functions.
253 ProcessFunction pfn = [this](Function* fp) {
254 return EliminateDeadInserts(fp);
255 };
256 bool modified = context()->ProcessReachableCallTree(pfn);
257 return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
258 }
259
260 } // namespace opt
261 } // namespace spvtools
262