• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2018 The Khronos Group Inc.
2 // Copyright (c) 2018 Valve Corporation
3 // Copyright (c) 2018 LunarG Inc.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 
17 #include "source/opt/dead_insert_elim_pass.h"
18 
19 #include "source/opt/composite.h"
20 #include "source/opt/ir_context.h"
21 #include "source/opt/iterator.h"
22 #include "spirv/1.2/GLSL.std.450.h"
23 
24 namespace spvtools {
25 namespace opt {
26 namespace {
// In-operand indices used by this pass when reading type, constant and
// OpCompositeInsert instructions.
//
// In-operand of an OpTypeVector read as its component count.
27 constexpr uint32_t kTypeVectorCountInIdx = 1;
// In-operand of an OpTypeMatrix read as its component count.
28 constexpr uint32_t kTypeMatrixCountInIdx = 1;
// In-operand of an OpTypeArray holding the id of its length.
29 constexpr uint32_t kTypeArrayLengthIdInIdx = 1;
// In-operand of the array length's type that is compared against 32 (width).
30 constexpr uint32_t kTypeIntWidthInIdx = 0;
// In-operand of an OpConstant read as its (32-bit) value.
31 constexpr uint32_t kConstantValueInIdx = 0;
// In-operand of an OpCompositeInsert holding the id of the inserted object.
32 constexpr uint32_t kInsertObjectIdInIdx = 0;
// In-operand of an OpCompositeInsert holding the id of the composite that is
// being inserted into (the next link up the insert chain).
33 constexpr uint32_t kInsertCompositeIdInIdx = 1;
34 }  // namespace
35 
NumComponents(Instruction * typeInst)36 uint32_t DeadInsertElimPass::NumComponents(Instruction* typeInst) {
37   switch (typeInst->opcode()) {
38     case spv::Op::OpTypeVector: {
39       return typeInst->GetSingleWordInOperand(kTypeVectorCountInIdx);
40     } break;
41     case spv::Op::OpTypeMatrix: {
42       return typeInst->GetSingleWordInOperand(kTypeMatrixCountInIdx);
43     } break;
44     case spv::Op::OpTypeArray: {
45       uint32_t lenId =
46           typeInst->GetSingleWordInOperand(kTypeArrayLengthIdInIdx);
47       Instruction* lenInst = get_def_use_mgr()->GetDef(lenId);
48       if (lenInst->opcode() != spv::Op::OpConstant) return 0;
49       uint32_t lenTypeId = lenInst->type_id();
50       Instruction* lenTypeInst = get_def_use_mgr()->GetDef(lenTypeId);
51       // TODO(greg-lunarg): Support non-32-bit array length
52       if (lenTypeInst->GetSingleWordInOperand(kTypeIntWidthInIdx) != 32)
53         return 0;
54       return lenInst->GetSingleWordInOperand(kConstantValueInIdx);
55     } break;
56     case spv::Op::OpTypeStruct: {
57       return typeInst->NumInOperands();
58     } break;
59     default: { return 0; } break;
60   }
61 }
62 
MarkInsertChain(Instruction * insertChain,std::vector<uint32_t> * pExtIndices,uint32_t extOffset,std::unordered_set<uint32_t> * visited_phis)63 void DeadInsertElimPass::MarkInsertChain(
64     Instruction* insertChain, std::vector<uint32_t>* pExtIndices,
65     uint32_t extOffset, std::unordered_set<uint32_t>* visited_phis) {
66   // Not currently optimizing array inserts.
67   Instruction* typeInst = get_def_use_mgr()->GetDef(insertChain->type_id());
68   if (typeInst->opcode() == spv::Op::OpTypeArray) return;
69   // Insert chains are only composed of inserts and phis
70   if (insertChain->opcode() != spv::Op::OpCompositeInsert &&
71       insertChain->opcode() != spv::Op::OpPhi)
72     return;
73   // If extract indices are empty, mark all subcomponents if type
74   // is constant length.
75   if (pExtIndices == nullptr) {
76     uint32_t cnum = NumComponents(typeInst);
77     if (cnum > 0) {
78       std::vector<uint32_t> extIndices;
79       for (uint32_t i = 0; i < cnum; i++) {
80         extIndices.clear();
81         extIndices.push_back(i);
82         std::unordered_set<uint32_t> sub_visited_phis;
83         MarkInsertChain(insertChain, &extIndices, 0, &sub_visited_phis);
84       }
85       return;
86     }
87   }
88   Instruction* insInst = insertChain;
89   while (insInst->opcode() == spv::Op::OpCompositeInsert) {
90     // If no extract indices, mark insert and inserted object (which might
91     // also be an insert chain) and continue up the chain though the input
92     // composite.
93     //
94     // Note: We mark inserted objects in this function (rather than in
95     // EliminateDeadInsertsOnePass) because in some cases, we can do it
96     // more accurately here.
97     if (pExtIndices == nullptr) {
98       liveInserts_.insert(insInst->result_id());
99       uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx);
100       std::unordered_set<uint32_t> obj_visited_phis;
101       MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0,
102                       &obj_visited_phis);
103     // If extract indices match insert, we are done. Mark insert and
104     // inserted object.
105     } else if (ExtInsMatch(*pExtIndices, insInst, extOffset)) {
106       liveInserts_.insert(insInst->result_id());
107       uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx);
108       std::unordered_set<uint32_t> obj_visited_phis;
109       MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0,
110                       &obj_visited_phis);
111       break;
112     // If non-matching intersection, mark insert
113     } else if (ExtInsConflict(*pExtIndices, insInst, extOffset)) {
114       liveInserts_.insert(insInst->result_id());
115       // If more extract indices than insert, we are done. Use remaining
116       // extract indices to mark inserted object.
117       uint32_t numInsertIndices = insInst->NumInOperands() - 2;
118       if (pExtIndices->size() - extOffset > numInsertIndices) {
119         uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx);
120         std::unordered_set<uint32_t> obj_visited_phis;
121         MarkInsertChain(get_def_use_mgr()->GetDef(objId), pExtIndices,
122                         extOffset + numInsertIndices, &obj_visited_phis);
123         break;
124       // If fewer extract indices than insert, also mark inserted object and
125       // continue up chain.
126       } else {
127         uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx);
128         std::unordered_set<uint32_t> obj_visited_phis;
129         MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0,
130                         &obj_visited_phis);
131       }
132     }
133     // Get next insert in chain
134     const uint32_t compId =
135         insInst->GetSingleWordInOperand(kInsertCompositeIdInIdx);
136     insInst = get_def_use_mgr()->GetDef(compId);
137   }
138   // If insert chain ended with phi, do recursive call on each operand
139   if (insInst->opcode() != spv::Op::OpPhi) return;
140   // Mark phi visited to prevent potential infinite loop. If phi is already
141   // visited, return to avoid infinite loop.
142   if (visited_phis->count(insInst->result_id()) != 0) return;
143   visited_phis->insert(insInst->result_id());
144 
145   // Phis may have duplicate inputs values for different edges, prune incoming
146   // ids lists before recursing.
147   std::vector<uint32_t> ids;
148   for (uint32_t i = 0; i < insInst->NumInOperands(); i += 2) {
149     ids.push_back(insInst->GetSingleWordInOperand(i));
150   }
151   std::sort(ids.begin(), ids.end());
152   auto new_end = std::unique(ids.begin(), ids.end());
153   for (auto id_iter = ids.begin(); id_iter != new_end; ++id_iter) {
154     Instruction* pi = get_def_use_mgr()->GetDef(*id_iter);
155     MarkInsertChain(pi, pExtIndices, extOffset, visited_phis);
156   }
157 }
158 
EliminateDeadInserts(Function * func)159 bool DeadInsertElimPass::EliminateDeadInserts(Function* func) {
160   bool modified = false;
161   bool lastmodified = true;
162   // Each pass can delete dead instructions, thus potentially revealing
163   // new dead insertions ie insertions with no uses.
164   while (lastmodified) {
165     lastmodified = EliminateDeadInsertsOnePass(func);
166     modified |= lastmodified;
167   }
168   return modified;
169 }
170 
EliminateDeadInsertsOnePass(Function * func)171 bool DeadInsertElimPass::EliminateDeadInsertsOnePass(Function* func) {
172   bool modified = false;
173   liveInserts_.clear();
174   visitedPhis_.clear();
175   // Mark all live inserts
176   for (auto bi = func->begin(); bi != func->end(); ++bi) {
177     for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
178       // Only process Inserts and composite Phis
179       spv::Op op = ii->opcode();
180       Instruction* typeInst = get_def_use_mgr()->GetDef(ii->type_id());
181       if (op != spv::Op::OpCompositeInsert &&
182           (op != spv::Op::OpPhi || !spvOpcodeIsComposite(typeInst->opcode())))
183         continue;
184       // The marking algorithm can be expensive for large arrays and the
185       // efficacy of eliminating dead inserts into arrays is questionable.
186       // Skip optimizing array inserts for now. Just mark them live.
187       // TODO(greg-lunarg): Eliminate dead array inserts
188       if (op == spv::Op::OpCompositeInsert) {
189         if (typeInst->opcode() == spv::Op::OpTypeArray) {
190           liveInserts_.insert(ii->result_id());
191           continue;
192         }
193       }
194       const uint32_t id = ii->result_id();
195       get_def_use_mgr()->ForEachUser(id, [&ii, this](Instruction* user) {
196         if (user->IsCommonDebugInstr()) return;
197         switch (user->opcode()) {
198           case spv::Op::OpCompositeInsert:
199           case spv::Op::OpPhi:
200             // Use by insert or phi does not initiate marking
201             break;
202           case spv::Op::OpCompositeExtract: {
203             // Capture extract indices
204             std::vector<uint32_t> extIndices;
205             uint32_t icnt = 0;
206             user->ForEachInOperand([&icnt, &extIndices](const uint32_t* idp) {
207               if (icnt > 0) extIndices.push_back(*idp);
208               ++icnt;
209             });
210             // Mark all inserts in chain that intersect with extract
211             std::unordered_set<uint32_t> visited_phis;
212             MarkInsertChain(&*ii, &extIndices, 0, &visited_phis);
213           } break;
214           default: {
215             // Mark inserts in chain for all components
216             MarkInsertChain(&*ii, nullptr, 0, nullptr);
217           } break;
218         }
219       });
220     }
221   }
222   // Find and disconnect dead inserts
223   std::vector<Instruction*> dead_instructions;
224   for (auto bi = func->begin(); bi != func->end(); ++bi) {
225     for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
226       if (ii->opcode() != spv::Op::OpCompositeInsert) continue;
227       const uint32_t id = ii->result_id();
228       if (liveInserts_.find(id) != liveInserts_.end()) continue;
229       const uint32_t replId =
230           ii->GetSingleWordInOperand(kInsertCompositeIdInIdx);
231       (void)context()->ReplaceAllUsesWith(id, replId);
232       dead_instructions.push_back(&*ii);
233       modified = true;
234     }
235   }
236   // DCE dead inserts
237   while (!dead_instructions.empty()) {
238     Instruction* inst = dead_instructions.back();
239     dead_instructions.pop_back();
240     DCEInst(inst, [&dead_instructions](Instruction* other_inst) {
241       auto i = std::find(dead_instructions.begin(), dead_instructions.end(),
242                          other_inst);
243       if (i != dead_instructions.end()) {
244         dead_instructions.erase(i);
245       }
246     });
247   }
248   return modified;
249 }
250 
Process()251 Pass::Status DeadInsertElimPass::Process() {
252   // Process all entry point functions.
253   ProcessFunction pfn = [this](Function* fp) {
254     return EliminateDeadInserts(fp);
255   };
256   bool modified = context()->ProcessReachableCallTree(pfn);
257   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
258 }
259 
260 }  // namespace opt
261 }  // namespace spvtools
262