• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2021 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/replace_desc_array_access_using_var_index.h"
16 
17 #include "source/opt/desc_sroa_util.h"
18 #include "source/opt/ir_builder.h"
19 #include "source/util/string_utils.h"
20 
21 namespace spvtools {
22 namespace opt {
23 namespace {
24 constexpr uint32_t kOpAccessChainInOperandIndexes = 1;
25 constexpr uint32_t kOpTypePointerInOperandType = 1;
26 constexpr uint32_t kOpTypeArrayInOperandType = 0;
27 constexpr uint32_t kOpTypeStructInOperandMember = 0;
28 IRContext::Analysis kAnalysisDefUseAndInstrToBlockMapping =
29     IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping;
30 
// Looks up |key| in |map| and returns the associated value. The key is
// required to be present; debug builds assert on a missing key.
uint32_t GetValueWithKeyExistenceCheck(
    uint32_t key, const std::unordered_map<uint32_t, uint32_t>& map) {
  const auto entry = map.find(key);
  const bool key_exists = entry != map.end();
  assert(key_exists && "Key does not exist");
  (void)key_exists;  // Only consumed by the assert.
  return entry->second;
}
37 
38 }  // namespace
39 
Process()40 Pass::Status ReplaceDescArrayAccessUsingVarIndex::Process() {
41   Status status = Status::SuccessWithoutChange;
42   for (Instruction& var : context()->types_values()) {
43     if (descsroautil::IsDescriptorArray(context(), &var)) {
44       if (ReplaceVariableAccessesWithConstantElements(&var))
45         status = Status::SuccessWithChange;
46     }
47   }
48   return status;
49 }
50 
51 bool ReplaceDescArrayAccessUsingVarIndex::
ReplaceVariableAccessesWithConstantElements(Instruction * var) const52     ReplaceVariableAccessesWithConstantElements(Instruction* var) const {
53   std::vector<Instruction*> work_list;
54   get_def_use_mgr()->ForEachUser(var, [&work_list](Instruction* use) {
55     switch (use->opcode()) {
56       case spv::Op::OpAccessChain:
57       case spv::Op::OpInBoundsAccessChain:
58         work_list.push_back(use);
59         break;
60       default:
61         break;
62     }
63   });
64 
65   bool updated = false;
66   for (Instruction* access_chain : work_list) {
67     if (descsroautil::GetAccessChainIndexAsConst(context(), access_chain) ==
68         nullptr) {
69       ReplaceAccessChain(var, access_chain);
70       updated = true;
71     }
72   }
73   // Note that we do not consider OpLoad and OpCompositeExtract because
74   // OpCompositeExtract always has constant literals for indices.
75   return updated;
76 }
77 
ReplaceAccessChain(Instruction * var,Instruction * access_chain) const78 void ReplaceDescArrayAccessUsingVarIndex::ReplaceAccessChain(
79     Instruction* var, Instruction* access_chain) const {
80   uint32_t number_of_elements =
81       descsroautil::GetNumberOfElementsForArrayOrStruct(context(), var);
82   assert(number_of_elements != 0 && "Number of element is 0");
83   if (number_of_elements == 1) {
84     UseConstIndexForAccessChain(access_chain, 0);
85     get_def_use_mgr()->AnalyzeInstUse(access_chain);
86     return;
87   }
88   ReplaceUsersOfAccessChain(access_chain, number_of_elements);
89 }
90 
ReplaceUsersOfAccessChain(Instruction * access_chain,uint32_t number_of_elements) const91 void ReplaceDescArrayAccessUsingVarIndex::ReplaceUsersOfAccessChain(
92     Instruction* access_chain, uint32_t number_of_elements) const {
93   std::vector<Instruction*> final_users;
94   CollectRecursiveUsersWithConcreteType(access_chain, &final_users);
95   for (auto* inst : final_users) {
96     std::deque<Instruction*> insts_to_be_cloned =
97         CollectRequiredImageAndAccessInsts(inst);
98     ReplaceNonUniformAccessWithSwitchCase(
99         inst, access_chain, number_of_elements, insts_to_be_cloned);
100   }
101 }
102 
CollectRecursiveUsersWithConcreteType(Instruction * access_chain,std::vector<Instruction * > * final_users) const103 void ReplaceDescArrayAccessUsingVarIndex::CollectRecursiveUsersWithConcreteType(
104     Instruction* access_chain, std::vector<Instruction*>* final_users) const {
105   std::queue<Instruction*> work_list;
106   work_list.push(access_chain);
107   while (!work_list.empty()) {
108     auto* inst_from_work_list = work_list.front();
109     work_list.pop();
110     get_def_use_mgr()->ForEachUser(
111         inst_from_work_list, [this, final_users, &work_list](Instruction* use) {
112           // TODO: Support Boolean type as well.
113           if (!use->HasResultId() || IsConcreteType(use->type_id())) {
114             final_users->push_back(use);
115           } else {
116             work_list.push(use);
117           }
118         });
119   }
120 }
121 
122 std::deque<Instruction*>
CollectRequiredImageAndAccessInsts(Instruction * user) const123 ReplaceDescArrayAccessUsingVarIndex::CollectRequiredImageAndAccessInsts(
124     Instruction* user) const {
125   std::unordered_set<uint32_t> seen_inst_ids;
126   std::queue<Instruction*> work_list;
127 
128   auto decision_to_include_operand = [this, &seen_inst_ids,
129                                       &work_list](uint32_t* idp) {
130     if (!seen_inst_ids.insert(*idp).second) return;
131     Instruction* operand = get_def_use_mgr()->GetDef(*idp);
132     if (context()->get_instr_block(operand) != nullptr &&
133         (HasImageOrImagePtrType(operand) ||
134          operand->opcode() == spv::Op::OpAccessChain ||
135          operand->opcode() == spv::Op::OpInBoundsAccessChain)) {
136       work_list.push(operand);
137     }
138   };
139 
140   std::deque<Instruction*> required_insts;
141   required_insts.push_front(user);
142   user->ForEachInId(decision_to_include_operand);
143   while (!work_list.empty()) {
144     auto* inst_from_work_list = work_list.front();
145     work_list.pop();
146     required_insts.push_front(inst_from_work_list);
147     inst_from_work_list->ForEachInId(decision_to_include_operand);
148   }
149   return required_insts;
150 }
151 
HasImageOrImagePtrType(const Instruction * inst) const152 bool ReplaceDescArrayAccessUsingVarIndex::HasImageOrImagePtrType(
153     const Instruction* inst) const {
154   assert(inst != nullptr && inst->type_id() != 0 && "Invalid instruction");
155   return IsImageOrImagePtrType(get_def_use_mgr()->GetDef(inst->type_id()));
156 }
157 
IsImageOrImagePtrType(const Instruction * type_inst) const158 bool ReplaceDescArrayAccessUsingVarIndex::IsImageOrImagePtrType(
159     const Instruction* type_inst) const {
160   if (type_inst->opcode() == spv::Op::OpTypeImage ||
161       type_inst->opcode() == spv::Op::OpTypeSampler ||
162       type_inst->opcode() == spv::Op::OpTypeSampledImage) {
163     return true;
164   }
165   if (type_inst->opcode() == spv::Op::OpTypePointer) {
166     Instruction* pointee_type_inst = get_def_use_mgr()->GetDef(
167         type_inst->GetSingleWordInOperand(kOpTypePointerInOperandType));
168     return IsImageOrImagePtrType(pointee_type_inst);
169   }
170   if (type_inst->opcode() == spv::Op::OpTypeArray) {
171     Instruction* element_type_inst = get_def_use_mgr()->GetDef(
172         type_inst->GetSingleWordInOperand(kOpTypeArrayInOperandType));
173     return IsImageOrImagePtrType(element_type_inst);
174   }
175   if (type_inst->opcode() != spv::Op::OpTypeStruct) return false;
176   for (uint32_t in_operand_idx = kOpTypeStructInOperandMember;
177        in_operand_idx < type_inst->NumInOperands(); ++in_operand_idx) {
178     Instruction* member_type_inst = get_def_use_mgr()->GetDef(
179         type_inst->GetSingleWordInOperand(kOpTypeStructInOperandMember));
180     if (IsImageOrImagePtrType(member_type_inst)) return true;
181   }
182   return false;
183 }
184 
IsConcreteType(uint32_t type_id) const185 bool ReplaceDescArrayAccessUsingVarIndex::IsConcreteType(
186     uint32_t type_id) const {
187   Instruction* type_inst = get_def_use_mgr()->GetDef(type_id);
188   if (type_inst->opcode() == spv::Op::OpTypeInt ||
189       type_inst->opcode() == spv::Op::OpTypeFloat) {
190     return true;
191   }
192   if (type_inst->opcode() == spv::Op::OpTypeVector ||
193       type_inst->opcode() == spv::Op::OpTypeMatrix ||
194       type_inst->opcode() == spv::Op::OpTypeArray) {
195     return IsConcreteType(type_inst->GetSingleWordInOperand(0));
196   }
197   if (type_inst->opcode() == spv::Op::OpTypeStruct) {
198     for (uint32_t i = 0; i < type_inst->NumInOperands(); ++i) {
199       if (!IsConcreteType(type_inst->GetSingleWordInOperand(i))) return false;
200     }
201     return true;
202   }
203   return false;
204 }
205 
CreateCaseBlock(Instruction * access_chain,uint32_t element_index,const std::deque<Instruction * > & insts_to_be_cloned,uint32_t branch_target_id,std::unordered_map<uint32_t,uint32_t> * old_ids_to_new_ids) const206 BasicBlock* ReplaceDescArrayAccessUsingVarIndex::CreateCaseBlock(
207     Instruction* access_chain, uint32_t element_index,
208     const std::deque<Instruction*>& insts_to_be_cloned,
209     uint32_t branch_target_id,
210     std::unordered_map<uint32_t, uint32_t>* old_ids_to_new_ids) const {
211   auto* case_block = CreateNewBlock();
212   AddConstElementAccessToCaseBlock(case_block, access_chain, element_index,
213                                    old_ids_to_new_ids);
214   CloneInstsToBlock(case_block, access_chain, insts_to_be_cloned,
215                     old_ids_to_new_ids);
216   AddBranchToBlock(case_block, branch_target_id);
217   UseNewIdsInBlock(case_block, *old_ids_to_new_ids);
218   return case_block;
219 }
220 
CloneInstsToBlock(BasicBlock * block,Instruction * inst_to_skip_cloning,const std::deque<Instruction * > & insts_to_be_cloned,std::unordered_map<uint32_t,uint32_t> * old_ids_to_new_ids) const221 void ReplaceDescArrayAccessUsingVarIndex::CloneInstsToBlock(
222     BasicBlock* block, Instruction* inst_to_skip_cloning,
223     const std::deque<Instruction*>& insts_to_be_cloned,
224     std::unordered_map<uint32_t, uint32_t>* old_ids_to_new_ids) const {
225   for (auto* inst_to_be_cloned : insts_to_be_cloned) {
226     if (inst_to_be_cloned == inst_to_skip_cloning) continue;
227     std::unique_ptr<Instruction> clone(inst_to_be_cloned->Clone(context()));
228     if (inst_to_be_cloned->HasResultId()) {
229       uint32_t new_id = context()->TakeNextId();
230       clone->SetResultId(new_id);
231       (*old_ids_to_new_ids)[inst_to_be_cloned->result_id()] = new_id;
232     }
233     get_def_use_mgr()->AnalyzeInstDefUse(clone.get());
234     context()->set_instr_block(clone.get(), block);
235     block->AddInstruction(std::move(clone));
236   }
237 }
238 
UseNewIdsInBlock(BasicBlock * block,const std::unordered_map<uint32_t,uint32_t> & old_ids_to_new_ids) const239 void ReplaceDescArrayAccessUsingVarIndex::UseNewIdsInBlock(
240     BasicBlock* block,
241     const std::unordered_map<uint32_t, uint32_t>& old_ids_to_new_ids) const {
242   for (auto block_itr = block->begin(); block_itr != block->end();
243        ++block_itr) {
244     (&*block_itr)->ForEachInId([&old_ids_to_new_ids](uint32_t* idp) {
245       auto old_ids_to_new_ids_itr = old_ids_to_new_ids.find(*idp);
246       if (old_ids_to_new_ids_itr == old_ids_to_new_ids.end()) return;
247       *idp = old_ids_to_new_ids_itr->second;
248     });
249     get_def_use_mgr()->AnalyzeInstUse(&*block_itr);
250   }
251 }
252 
ReplaceNonUniformAccessWithSwitchCase(Instruction * access_chain_final_user,Instruction * access_chain,uint32_t number_of_elements,const std::deque<Instruction * > & insts_to_be_cloned) const253 void ReplaceDescArrayAccessUsingVarIndex::ReplaceNonUniformAccessWithSwitchCase(
254     Instruction* access_chain_final_user, Instruction* access_chain,
255     uint32_t number_of_elements,
256     const std::deque<Instruction*>& insts_to_be_cloned) const {
257   auto* block = context()->get_instr_block(access_chain_final_user);
258   // If the instruction does not belong to a block (i.e. in the case of
259   // OpDecorate), no replacement is needed.
260   if (!block) return;
261 
262   // Create merge block and add terminator
263   auto* merge_block = SeparateInstructionsIntoNewBlock(
264       block, access_chain_final_user->NextNode());
265 
266   auto* function = block->GetParent();
267 
268   // Add case blocks
269   std::vector<uint32_t> phi_operands;
270   std::vector<uint32_t> case_block_ids;
271   for (uint32_t idx = 0; idx < number_of_elements; ++idx) {
272     std::unordered_map<uint32_t, uint32_t> old_ids_to_new_ids_for_cloned_insts;
273     std::unique_ptr<BasicBlock> case_block(CreateCaseBlock(
274         access_chain, idx, insts_to_be_cloned, merge_block->id(),
275         &old_ids_to_new_ids_for_cloned_insts));
276     case_block_ids.push_back(case_block->id());
277     function->InsertBasicBlockBefore(std::move(case_block), merge_block);
278 
279     // Keep the operand for OpPhi
280     if (!access_chain_final_user->HasResultId()) continue;
281     uint32_t phi_operand =
282         GetValueWithKeyExistenceCheck(access_chain_final_user->result_id(),
283                                       old_ids_to_new_ids_for_cloned_insts);
284     phi_operands.push_back(phi_operand);
285   }
286 
287   // Create default block
288   std::unique_ptr<BasicBlock> default_block(
289       CreateDefaultBlock(access_chain_final_user->HasResultId(), &phi_operands,
290                          merge_block->id()));
291   uint32_t default_block_id = default_block->id();
292   function->InsertBasicBlockBefore(std::move(default_block), merge_block);
293 
294   // Create OpSwitch
295   uint32_t access_chain_index_var_id =
296       descsroautil::GetFirstIndexOfAccessChain(access_chain);
297   AddSwitchForAccessChain(block, access_chain_index_var_id, default_block_id,
298                           merge_block->id(), case_block_ids);
299 
300   // Create phi instructions
301   if (!phi_operands.empty()) {
302     uint32_t phi_id = CreatePhiInstruction(merge_block, phi_operands,
303                                            case_block_ids, default_block_id);
304     context()->ReplaceAllUsesWith(access_chain_final_user->result_id(), phi_id);
305   }
306 
307   // Replace OpPhi incoming block operand that uses |block| with |merge_block|
308   ReplacePhiIncomingBlock(block->id(), merge_block->id());
309 }
310 
311 BasicBlock*
SeparateInstructionsIntoNewBlock(BasicBlock * block,Instruction * separation_begin_inst) const312 ReplaceDescArrayAccessUsingVarIndex::SeparateInstructionsIntoNewBlock(
313     BasicBlock* block, Instruction* separation_begin_inst) const {
314   auto separation_begin = block->begin();
315   while (separation_begin != block->end() &&
316          &*separation_begin != separation_begin_inst) {
317     ++separation_begin;
318   }
319   return block->SplitBasicBlock(context(), context()->TakeNextId(),
320                                 separation_begin);
321 }
322 
CreateNewBlock() const323 BasicBlock* ReplaceDescArrayAccessUsingVarIndex::CreateNewBlock() const {
324   auto* new_block = new BasicBlock(std::unique_ptr<Instruction>(new Instruction(
325       context(), spv::Op::OpLabel, 0, context()->TakeNextId(), {})));
326   get_def_use_mgr()->AnalyzeInstDefUse(new_block->GetLabelInst());
327   context()->set_instr_block(new_block->GetLabelInst(), new_block);
328   return new_block;
329 }
330 
UseConstIndexForAccessChain(Instruction * access_chain,uint32_t const_element_idx) const331 void ReplaceDescArrayAccessUsingVarIndex::UseConstIndexForAccessChain(
332     Instruction* access_chain, uint32_t const_element_idx) const {
333   uint32_t const_element_idx_id =
334       context()->get_constant_mgr()->GetUIntConstId(const_element_idx);
335   access_chain->SetInOperand(kOpAccessChainInOperandIndexes,
336                              {const_element_idx_id});
337 }
338 
AddConstElementAccessToCaseBlock(BasicBlock * case_block,Instruction * access_chain,uint32_t const_element_idx,std::unordered_map<uint32_t,uint32_t> * old_ids_to_new_ids) const339 void ReplaceDescArrayAccessUsingVarIndex::AddConstElementAccessToCaseBlock(
340     BasicBlock* case_block, Instruction* access_chain,
341     uint32_t const_element_idx,
342     std::unordered_map<uint32_t, uint32_t>* old_ids_to_new_ids) const {
343   std::unique_ptr<Instruction> access_clone(access_chain->Clone(context()));
344   UseConstIndexForAccessChain(access_clone.get(), const_element_idx);
345 
346   uint32_t new_access_id = context()->TakeNextId();
347   (*old_ids_to_new_ids)[access_clone->result_id()] = new_access_id;
348   access_clone->SetResultId(new_access_id);
349   get_def_use_mgr()->AnalyzeInstDefUse(access_clone.get());
350 
351   context()->set_instr_block(access_clone.get(), case_block);
352   case_block->AddInstruction(std::move(access_clone));
353 }
354 
AddBranchToBlock(BasicBlock * parent_block,uint32_t branch_destination) const355 void ReplaceDescArrayAccessUsingVarIndex::AddBranchToBlock(
356     BasicBlock* parent_block, uint32_t branch_destination) const {
357   InstructionBuilder builder{context(), parent_block,
358                              kAnalysisDefUseAndInstrToBlockMapping};
359   builder.AddBranch(branch_destination);
360 }
361 
CreateDefaultBlock(bool null_const_for_phi_is_needed,std::vector<uint32_t> * phi_operands,uint32_t merge_block_id) const362 BasicBlock* ReplaceDescArrayAccessUsingVarIndex::CreateDefaultBlock(
363     bool null_const_for_phi_is_needed, std::vector<uint32_t>* phi_operands,
364     uint32_t merge_block_id) const {
365   auto* default_block = CreateNewBlock();
366   AddBranchToBlock(default_block, merge_block_id);
367   if (!null_const_for_phi_is_needed) return default_block;
368 
369   // Create null value for OpPhi
370   Instruction* inst = context()->get_def_use_mgr()->GetDef((*phi_operands)[0]);
371   auto* null_const_inst = GetConstNull(inst->type_id());
372   phi_operands->push_back(null_const_inst->result_id());
373   return default_block;
374 }
375 
GetConstNull(uint32_t type_id) const376 Instruction* ReplaceDescArrayAccessUsingVarIndex::GetConstNull(
377     uint32_t type_id) const {
378   assert(type_id != 0 && "Result type is expected");
379   auto* type = context()->get_type_mgr()->GetType(type_id);
380   auto* null_const = context()->get_constant_mgr()->GetConstant(type, {});
381   return context()->get_constant_mgr()->GetDefiningInstruction(null_const);
382 }
383 
AddSwitchForAccessChain(BasicBlock * parent_block,uint32_t access_chain_index_var_id,uint32_t default_id,uint32_t merge_id,const std::vector<uint32_t> & case_block_ids) const384 void ReplaceDescArrayAccessUsingVarIndex::AddSwitchForAccessChain(
385     BasicBlock* parent_block, uint32_t access_chain_index_var_id,
386     uint32_t default_id, uint32_t merge_id,
387     const std::vector<uint32_t>& case_block_ids) const {
388   InstructionBuilder builder{context(), parent_block,
389                              kAnalysisDefUseAndInstrToBlockMapping};
390   std::vector<std::pair<Operand::OperandData, uint32_t>> cases;
391   for (uint32_t i = 0; i < static_cast<uint32_t>(case_block_ids.size()); ++i) {
392     cases.emplace_back(Operand::OperandData{i}, case_block_ids[i]);
393   }
394   builder.AddSwitch(access_chain_index_var_id, default_id, cases, merge_id);
395 }
396 
CreatePhiInstruction(BasicBlock * parent_block,const std::vector<uint32_t> & phi_operands,const std::vector<uint32_t> & case_block_ids,uint32_t default_block_id) const397 uint32_t ReplaceDescArrayAccessUsingVarIndex::CreatePhiInstruction(
398     BasicBlock* parent_block, const std::vector<uint32_t>& phi_operands,
399     const std::vector<uint32_t>& case_block_ids,
400     uint32_t default_block_id) const {
401   std::vector<uint32_t> incomings;
402   assert(case_block_ids.size() + 1 == phi_operands.size() &&
403          "Number of Phi operands must be exactly 1 bigger than the one of case "
404          "blocks");
405   for (size_t i = 0; i < case_block_ids.size(); ++i) {
406     incomings.push_back(phi_operands[i]);
407     incomings.push_back(case_block_ids[i]);
408   }
409   incomings.push_back(phi_operands.back());
410   incomings.push_back(default_block_id);
411 
412   InstructionBuilder builder{context(), &*parent_block->begin(),
413                              kAnalysisDefUseAndInstrToBlockMapping};
414   uint32_t phi_result_type_id =
415       context()->get_def_use_mgr()->GetDef(phi_operands[0])->type_id();
416   auto* phi = builder.AddPhi(phi_result_type_id, incomings);
417   return phi->result_id();
418 }
419 
ReplacePhiIncomingBlock(uint32_t old_incoming_block_id,uint32_t new_incoming_block_id) const420 void ReplaceDescArrayAccessUsingVarIndex::ReplacePhiIncomingBlock(
421     uint32_t old_incoming_block_id, uint32_t new_incoming_block_id) const {
422   context()->ReplaceAllUsesWithPredicate(
423       old_incoming_block_id, new_incoming_block_id,
424       [](Instruction* use) { return use->opcode() == spv::Op::OpPhi; });
425 }
426 
427 }  // namespace opt
428 }  // namespace spvtools
429