• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2018 The Khronos Group Inc.
2 // Copyright (c) 2018 Valve Corporation
3 // Copyright (c) 2018 LunarG Inc.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 
17 #include "source/opt/dead_insert_elim_pass.h"
18 
19 #include "source/opt/composite.h"
20 #include "source/opt/ir_context.h"
21 #include "source/opt/iterator.h"
22 #include "spirv/1.2/GLSL.std.450.h"
23 
24 namespace spvtools {
25 namespace opt {
26 
namespace {

// In-operand indices into type-declaration instructions.
constexpr uint32_t kTypeVectorCountInIdx = 1;    // OpTypeVector: component count
constexpr uint32_t kTypeMatrixCountInIdx = 1;    // OpTypeMatrix: column count
constexpr uint32_t kTypeArrayLengthIdInIdx = 1;  // OpTypeArray: length <id>
constexpr uint32_t kTypeIntWidthInIdx = 0;       // OpTypeInt: bit width

// In-operand index of the (first word of the) literal value of OpConstant.
constexpr uint32_t kConstantValueInIdx = 0;

// In-operand indices into OpCompositeInsert; the insert indices follow them.
constexpr uint32_t kInsertObjectIdInIdx = 0;     // object being inserted
constexpr uint32_t kInsertCompositeIdInIdx = 1;  // composite inserted into

}  // anonymous namespace
38 
NumComponents(Instruction * typeInst)39 uint32_t DeadInsertElimPass::NumComponents(Instruction* typeInst) {
40   switch (typeInst->opcode()) {
41     case SpvOpTypeVector: {
42       return typeInst->GetSingleWordInOperand(kTypeVectorCountInIdx);
43     } break;
44     case SpvOpTypeMatrix: {
45       return typeInst->GetSingleWordInOperand(kTypeMatrixCountInIdx);
46     } break;
47     case SpvOpTypeArray: {
48       uint32_t lenId =
49           typeInst->GetSingleWordInOperand(kTypeArrayLengthIdInIdx);
50       Instruction* lenInst = get_def_use_mgr()->GetDef(lenId);
51       if (lenInst->opcode() != SpvOpConstant) return 0;
52       uint32_t lenTypeId = lenInst->type_id();
53       Instruction* lenTypeInst = get_def_use_mgr()->GetDef(lenTypeId);
54       // TODO(greg-lunarg): Support non-32-bit array length
55       if (lenTypeInst->GetSingleWordInOperand(kTypeIntWidthInIdx) != 32)
56         return 0;
57       return lenInst->GetSingleWordInOperand(kConstantValueInIdx);
58     } break;
59     case SpvOpTypeStruct: {
60       return typeInst->NumInOperands();
61     } break;
62     default: { return 0; } break;
63   }
64 }
65 
MarkInsertChain(Instruction * insertChain,std::vector<uint32_t> * pExtIndices,uint32_t extOffset,std::unordered_set<uint32_t> * visited_phis)66 void DeadInsertElimPass::MarkInsertChain(
67     Instruction* insertChain, std::vector<uint32_t>* pExtIndices,
68     uint32_t extOffset, std::unordered_set<uint32_t>* visited_phis) {
69   // Not currently optimizing array inserts.
70   Instruction* typeInst = get_def_use_mgr()->GetDef(insertChain->type_id());
71   if (typeInst->opcode() == SpvOpTypeArray) return;
72   // Insert chains are only composed of inserts and phis
73   if (insertChain->opcode() != SpvOpCompositeInsert &&
74       insertChain->opcode() != SpvOpPhi)
75     return;
76   // If extract indices are empty, mark all subcomponents if type
77   // is constant length.
78   if (pExtIndices == nullptr) {
79     uint32_t cnum = NumComponents(typeInst);
80     if (cnum > 0) {
81       std::vector<uint32_t> extIndices;
82       for (uint32_t i = 0; i < cnum; i++) {
83         extIndices.clear();
84         extIndices.push_back(i);
85         std::unordered_set<uint32_t> sub_visited_phis;
86         MarkInsertChain(insertChain, &extIndices, 0, &sub_visited_phis);
87       }
88       return;
89     }
90   }
91   Instruction* insInst = insertChain;
92   while (insInst->opcode() == SpvOpCompositeInsert) {
93     // If no extract indices, mark insert and inserted object (which might
94     // also be an insert chain) and continue up the chain though the input
95     // composite.
96     //
97     // Note: We mark inserted objects in this function (rather than in
98     // EliminateDeadInsertsOnePass) because in some cases, we can do it
99     // more accurately here.
100     if (pExtIndices == nullptr) {
101       liveInserts_.insert(insInst->result_id());
102       uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx);
103       std::unordered_set<uint32_t> obj_visited_phis;
104       MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0,
105                       &obj_visited_phis);
106     // If extract indices match insert, we are done. Mark insert and
107     // inserted object.
108     } else if (ExtInsMatch(*pExtIndices, insInst, extOffset)) {
109       liveInserts_.insert(insInst->result_id());
110       uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx);
111       std::unordered_set<uint32_t> obj_visited_phis;
112       MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0,
113                       &obj_visited_phis);
114       break;
115     // If non-matching intersection, mark insert
116     } else if (ExtInsConflict(*pExtIndices, insInst, extOffset)) {
117       liveInserts_.insert(insInst->result_id());
118       // If more extract indices than insert, we are done. Use remaining
119       // extract indices to mark inserted object.
120       uint32_t numInsertIndices = insInst->NumInOperands() - 2;
121       if (pExtIndices->size() - extOffset > numInsertIndices) {
122         uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx);
123         std::unordered_set<uint32_t> obj_visited_phis;
124         MarkInsertChain(get_def_use_mgr()->GetDef(objId), pExtIndices,
125                         extOffset + numInsertIndices, &obj_visited_phis);
126         break;
127       // If fewer extract indices than insert, also mark inserted object and
128       // continue up chain.
129       } else {
130         uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx);
131         std::unordered_set<uint32_t> obj_visited_phis;
132         MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0,
133                         &obj_visited_phis);
134       }
135     }
136     // Get next insert in chain
137     const uint32_t compId =
138         insInst->GetSingleWordInOperand(kInsertCompositeIdInIdx);
139     insInst = get_def_use_mgr()->GetDef(compId);
140   }
141   // If insert chain ended with phi, do recursive call on each operand
142   if (insInst->opcode() != SpvOpPhi) return;
143   // Mark phi visited to prevent potential infinite loop. If phi is already
144   // visited, return to avoid infinite loop.
145   if (visited_phis->count(insInst->result_id()) != 0) return;
146   visited_phis->insert(insInst->result_id());
147 
148   // Phis may have duplicate inputs values for different edges, prune incoming
149   // ids lists before recursing.
150   std::vector<uint32_t> ids;
151   for (uint32_t i = 0; i < insInst->NumInOperands(); i += 2) {
152     ids.push_back(insInst->GetSingleWordInOperand(i));
153   }
154   std::sort(ids.begin(), ids.end());
155   auto new_end = std::unique(ids.begin(), ids.end());
156   for (auto id_iter = ids.begin(); id_iter != new_end; ++id_iter) {
157     Instruction* pi = get_def_use_mgr()->GetDef(*id_iter);
158     MarkInsertChain(pi, pExtIndices, extOffset, visited_phis);
159   }
160 }
161 
EliminateDeadInserts(Function * func)162 bool DeadInsertElimPass::EliminateDeadInserts(Function* func) {
163   bool modified = false;
164   bool lastmodified = true;
165   // Each pass can delete dead instructions, thus potentially revealing
166   // new dead insertions ie insertions with no uses.
167   while (lastmodified) {
168     lastmodified = EliminateDeadInsertsOnePass(func);
169     modified |= lastmodified;
170   }
171   return modified;
172 }
173 
EliminateDeadInsertsOnePass(Function * func)174 bool DeadInsertElimPass::EliminateDeadInsertsOnePass(Function* func) {
175   bool modified = false;
176   liveInserts_.clear();
177   visitedPhis_.clear();
178   // Mark all live inserts
179   for (auto bi = func->begin(); bi != func->end(); ++bi) {
180     for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
181       // Only process Inserts and composite Phis
182       SpvOp op = ii->opcode();
183       Instruction* typeInst = get_def_use_mgr()->GetDef(ii->type_id());
184       if (op != SpvOpCompositeInsert &&
185           (op != SpvOpPhi || !spvOpcodeIsComposite(typeInst->opcode())))
186         continue;
187       // The marking algorithm can be expensive for large arrays and the
188       // efficacy of eliminating dead inserts into arrays is questionable.
189       // Skip optimizing array inserts for now. Just mark them live.
190       // TODO(greg-lunarg): Eliminate dead array inserts
191       if (op == SpvOpCompositeInsert) {
192         if (typeInst->opcode() == SpvOpTypeArray) {
193           liveInserts_.insert(ii->result_id());
194           continue;
195         }
196       }
197       const uint32_t id = ii->result_id();
198       get_def_use_mgr()->ForEachUser(id, [&ii, this](Instruction* user) {
199         switch (user->opcode()) {
200           case SpvOpCompositeInsert:
201           case SpvOpPhi:
202             // Use by insert or phi does not initiate marking
203             break;
204           case SpvOpCompositeExtract: {
205             // Capture extract indices
206             std::vector<uint32_t> extIndices;
207             uint32_t icnt = 0;
208             user->ForEachInOperand([&icnt, &extIndices](const uint32_t* idp) {
209               if (icnt > 0) extIndices.push_back(*idp);
210               ++icnt;
211             });
212             // Mark all inserts in chain that intersect with extract
213             std::unordered_set<uint32_t> visited_phis;
214             MarkInsertChain(&*ii, &extIndices, 0, &visited_phis);
215           } break;
216           default: {
217             // Mark inserts in chain for all components
218             MarkInsertChain(&*ii, nullptr, 0, nullptr);
219           } break;
220         }
221       });
222     }
223   }
224   // Find and disconnect dead inserts
225   std::vector<Instruction*> dead_instructions;
226   for (auto bi = func->begin(); bi != func->end(); ++bi) {
227     for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
228       if (ii->opcode() != SpvOpCompositeInsert) continue;
229       const uint32_t id = ii->result_id();
230       if (liveInserts_.find(id) != liveInserts_.end()) continue;
231       const uint32_t replId =
232           ii->GetSingleWordInOperand(kInsertCompositeIdInIdx);
233       (void)context()->ReplaceAllUsesWith(id, replId);
234       dead_instructions.push_back(&*ii);
235       modified = true;
236     }
237   }
238   // DCE dead inserts
239   while (!dead_instructions.empty()) {
240     Instruction* inst = dead_instructions.back();
241     dead_instructions.pop_back();
242     DCEInst(inst, [&dead_instructions](Instruction* other_inst) {
243       auto i = std::find(dead_instructions.begin(), dead_instructions.end(),
244                          other_inst);
245       if (i != dead_instructions.end()) {
246         dead_instructions.erase(i);
247       }
248     });
249   }
250   return modified;
251 }
252 
Process()253 Pass::Status DeadInsertElimPass::Process() {
254   // Process all entry point functions.
255   ProcessFunction pfn = [this](Function* fp) {
256     return EliminateDeadInserts(fp);
257   };
258   bool modified = context()->ProcessEntryPointCallTree(pfn);
259   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
260 }
261 
262 }  // namespace opt
263 }  // namespace spvtools
264