// Copyright (c) 2017 The Khronos Group Inc.
// Copyright (c) 2017 Valve Corporation
// Copyright (c) 2017 LunarG Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
16 
17 #include "source/opt/local_access_chain_convert_pass.h"
18 
19 #include "ir_builder.h"
20 #include "ir_context.h"
21 #include "iterator.h"
22 
23 namespace spvtools {
24 namespace opt {
25 
namespace {

// In-operand positions used by this pass when reading instructions.

// OpStore: in-operand that holds the id of the value being stored.
constexpr uint32_t kStoreValIdInIdx = 1;
// Access chain: in-operand that holds the id of the base pointer.
constexpr uint32_t kAccessChainPtrIdInIdx = 0;
// Constant: in-operand that holds the literal value word.
constexpr uint32_t kConstantValueInIdx = 0;
// OpTypeInt: in-operand that holds the bit width.
constexpr uint32_t kTypeIntWidthInIdx = 0;

}  // anonymous namespace
34 
BuildAndAppendInst(SpvOp opcode,uint32_t typeId,uint32_t resultId,const std::vector<Operand> & in_opnds,std::vector<std::unique_ptr<Instruction>> * newInsts)35 void LocalAccessChainConvertPass::BuildAndAppendInst(
36     SpvOp opcode, uint32_t typeId, uint32_t resultId,
37     const std::vector<Operand>& in_opnds,
38     std::vector<std::unique_ptr<Instruction>>* newInsts) {
39   std::unique_ptr<Instruction> newInst(
40       new Instruction(context(), opcode, typeId, resultId, in_opnds));
41   get_def_use_mgr()->AnalyzeInstDefUse(&*newInst);
42   newInsts->emplace_back(std::move(newInst));
43 }
44 
BuildAndAppendVarLoad(const Instruction * ptrInst,uint32_t * varId,uint32_t * varPteTypeId,std::vector<std::unique_ptr<Instruction>> * newInsts)45 uint32_t LocalAccessChainConvertPass::BuildAndAppendVarLoad(
46     const Instruction* ptrInst, uint32_t* varId, uint32_t* varPteTypeId,
47     std::vector<std::unique_ptr<Instruction>>* newInsts) {
48   const uint32_t ldResultId = TakeNextId();
49   if (ldResultId == 0) {
50     return 0;
51   }
52 
53   *varId = ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx);
54   const Instruction* varInst = get_def_use_mgr()->GetDef(*varId);
55   assert(varInst->opcode() == SpvOpVariable);
56   *varPteTypeId = GetPointeeTypeId(varInst);
57   BuildAndAppendInst(SpvOpLoad, *varPteTypeId, ldResultId,
58                      {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {*varId}}},
59                      newInsts);
60   return ldResultId;
61 }
62 
AppendConstantOperands(const Instruction * ptrInst,std::vector<Operand> * in_opnds)63 void LocalAccessChainConvertPass::AppendConstantOperands(
64     const Instruction* ptrInst, std::vector<Operand>* in_opnds) {
65   uint32_t iidIdx = 0;
66   ptrInst->ForEachInId([&iidIdx, &in_opnds, this](const uint32_t* iid) {
67     if (iidIdx > 0) {
68       const Instruction* cInst = get_def_use_mgr()->GetDef(*iid);
69       uint32_t val = cInst->GetSingleWordInOperand(kConstantValueInIdx);
70       in_opnds->push_back(
71           {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {val}});
72     }
73     ++iidIdx;
74   });
75 }
76 
ReplaceAccessChainLoad(const Instruction * address_inst,Instruction * original_load)77 bool LocalAccessChainConvertPass::ReplaceAccessChainLoad(
78     const Instruction* address_inst, Instruction* original_load) {
79   // Build and append load of variable in ptrInst
80   if (address_inst->NumInOperands() == 1) {
81     // An access chain with no indices is essentially a copy.  All that is
82     // needed is to propagate the address.
83     context()->ReplaceAllUsesWith(
84         address_inst->result_id(),
85         address_inst->GetSingleWordInOperand(kAccessChainPtrIdInIdx));
86     return true;
87   }
88 
89   std::vector<std::unique_ptr<Instruction>> new_inst;
90   uint32_t varId;
91   uint32_t varPteTypeId;
92   const uint32_t ldResultId =
93       BuildAndAppendVarLoad(address_inst, &varId, &varPteTypeId, &new_inst);
94   if (ldResultId == 0) {
95     return false;
96   }
97 
98   new_inst[0]->UpdateDebugInfoFrom(original_load);
99   context()->get_decoration_mgr()->CloneDecorations(
100       original_load->result_id(), ldResultId, {SpvDecorationRelaxedPrecision});
101   original_load->InsertBefore(std::move(new_inst));
102   context()->get_debug_info_mgr()->AnalyzeDebugInst(
103       original_load->PreviousNode());
104 
105   // Rewrite |original_load| into an extract.
106   Instruction::OperandList new_operands;
107 
108   // copy the result id and the type id to the new operand list.
109   new_operands.emplace_back(original_load->GetOperand(0));
110   new_operands.emplace_back(original_load->GetOperand(1));
111 
112   new_operands.emplace_back(
113       Operand({spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}));
114   AppendConstantOperands(address_inst, &new_operands);
115   original_load->SetOpcode(SpvOpCompositeExtract);
116   original_load->ReplaceOperands(new_operands);
117   context()->UpdateDefUse(original_load);
118   return true;
119 }
120 
GenAccessChainStoreReplacement(const Instruction * ptrInst,uint32_t valId,std::vector<std::unique_ptr<Instruction>> * newInsts)121 bool LocalAccessChainConvertPass::GenAccessChainStoreReplacement(
122     const Instruction* ptrInst, uint32_t valId,
123     std::vector<std::unique_ptr<Instruction>>* newInsts) {
124   if (ptrInst->NumInOperands() == 1) {
125     // An access chain with no indices is essentially a copy.  However, we still
126     // have to create a new store because the old ones will be deleted.
127     BuildAndAppendInst(
128         SpvOpStore, 0, 0,
129         {{spv_operand_type_t::SPV_OPERAND_TYPE_ID,
130           {ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx)}},
131          {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}}},
132         newInsts);
133     return true;
134   }
135 
136   // Build and append load of variable in ptrInst
137   uint32_t varId;
138   uint32_t varPteTypeId;
139   const uint32_t ldResultId =
140       BuildAndAppendVarLoad(ptrInst, &varId, &varPteTypeId, newInsts);
141   if (ldResultId == 0) {
142     return false;
143   }
144 
145   context()->get_decoration_mgr()->CloneDecorations(
146       varId, ldResultId, {SpvDecorationRelaxedPrecision});
147 
148   // Build and append Insert
149   const uint32_t insResultId = TakeNextId();
150   if (insResultId == 0) {
151     return false;
152   }
153   std::vector<Operand> ins_in_opnds = {
154       {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}},
155       {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}};
156   AppendConstantOperands(ptrInst, &ins_in_opnds);
157   BuildAndAppendInst(SpvOpCompositeInsert, varPteTypeId, insResultId,
158                      ins_in_opnds, newInsts);
159 
160   context()->get_decoration_mgr()->CloneDecorations(
161       varId, insResultId, {SpvDecorationRelaxedPrecision});
162 
163   // Build and append Store
164   BuildAndAppendInst(SpvOpStore, 0, 0,
165                      {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {varId}},
166                       {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {insResultId}}},
167                      newInsts);
168   return true;
169 }
170 
IsConstantIndexAccessChain(const Instruction * acp) const171 bool LocalAccessChainConvertPass::IsConstantIndexAccessChain(
172     const Instruction* acp) const {
173   uint32_t inIdx = 0;
174   return acp->WhileEachInId([&inIdx, this](const uint32_t* tid) {
175     if (inIdx > 0) {
176       Instruction* opInst = get_def_use_mgr()->GetDef(*tid);
177       if (opInst->opcode() != SpvOpConstant) return false;
178     }
179     ++inIdx;
180     return true;
181   });
182 }
183 
HasOnlySupportedRefs(uint32_t ptrId)184 bool LocalAccessChainConvertPass::HasOnlySupportedRefs(uint32_t ptrId) {
185   if (supported_ref_ptrs_.find(ptrId) != supported_ref_ptrs_.end()) return true;
186   if (get_def_use_mgr()->WhileEachUser(ptrId, [this](Instruction* user) {
187         if (user->GetOpenCL100DebugOpcode() == OpenCLDebugInfo100DebugValue ||
188             user->GetOpenCL100DebugOpcode() == OpenCLDebugInfo100DebugDeclare) {
189           return true;
190         }
191         SpvOp op = user->opcode();
192         if (IsNonPtrAccessChain(op) || op == SpvOpCopyObject) {
193           if (!HasOnlySupportedRefs(user->result_id())) {
194             return false;
195           }
196         } else if (op != SpvOpStore && op != SpvOpLoad && op != SpvOpName &&
197                    !IsNonTypeDecorate(op)) {
198           return false;
199         }
200         return true;
201       })) {
202     supported_ref_ptrs_.insert(ptrId);
203     return true;
204   }
205   return false;
206 }
207 
FindTargetVars(Function * func)208 void LocalAccessChainConvertPass::FindTargetVars(Function* func) {
209   for (auto bi = func->begin(); bi != func->end(); ++bi) {
210     for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
211       switch (ii->opcode()) {
212         case SpvOpStore:
213         case SpvOpLoad: {
214           uint32_t varId;
215           Instruction* ptrInst = GetPtr(&*ii, &varId);
216           if (!IsTargetVar(varId)) break;
217           const SpvOp op = ptrInst->opcode();
218           // Rule out variables with non-supported refs eg function calls
219           if (!HasOnlySupportedRefs(varId)) {
220             seen_non_target_vars_.insert(varId);
221             seen_target_vars_.erase(varId);
222             break;
223           }
224           // Rule out variables with nested access chains
225           // TODO(): Convert nested access chains
226           if (IsNonPtrAccessChain(op) && ptrInst->GetSingleWordInOperand(
227                                              kAccessChainPtrIdInIdx) != varId) {
228             seen_non_target_vars_.insert(varId);
229             seen_target_vars_.erase(varId);
230             break;
231           }
232           // Rule out variables accessed with non-constant indices
233           if (!IsConstantIndexAccessChain(ptrInst)) {
234             seen_non_target_vars_.insert(varId);
235             seen_target_vars_.erase(varId);
236             break;
237           }
238         } break;
239         default:
240           break;
241       }
242     }
243   }
244 }
245 
ConvertLocalAccessChains(Function * func)246 Pass::Status LocalAccessChainConvertPass::ConvertLocalAccessChains(
247     Function* func) {
248   FindTargetVars(func);
249   // Replace access chains of all targeted variables with equivalent
250   // extract and insert sequences
251   bool modified = false;
252   for (auto bi = func->begin(); bi != func->end(); ++bi) {
253     std::vector<Instruction*> dead_instructions;
254     for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
255       switch (ii->opcode()) {
256         case SpvOpLoad: {
257           uint32_t varId;
258           Instruction* ptrInst = GetPtr(&*ii, &varId);
259           if (!IsNonPtrAccessChain(ptrInst->opcode())) break;
260           if (!IsTargetVar(varId)) break;
261           if (!ReplaceAccessChainLoad(ptrInst, &*ii)) {
262             return Status::Failure;
263           }
264           modified = true;
265         } break;
266         case SpvOpStore: {
267           uint32_t varId;
268           Instruction* store = &*ii;
269           Instruction* ptrInst = GetPtr(store, &varId);
270           if (!IsNonPtrAccessChain(ptrInst->opcode())) break;
271           if (!IsTargetVar(varId)) break;
272           std::vector<std::unique_ptr<Instruction>> newInsts;
273           uint32_t valId = store->GetSingleWordInOperand(kStoreValIdInIdx);
274           if (!GenAccessChainStoreReplacement(ptrInst, valId, &newInsts)) {
275             return Status::Failure;
276           }
277           size_t num_of_instructions_to_skip = newInsts.size() - 1;
278           dead_instructions.push_back(store);
279           ++ii;
280           ii = ii.InsertBefore(std::move(newInsts));
281           for (size_t i = 0; i < num_of_instructions_to_skip; ++i) {
282             ii->UpdateDebugInfoFrom(store);
283             context()->get_debug_info_mgr()->AnalyzeDebugInst(&*ii);
284             ++ii;
285           }
286           ii->UpdateDebugInfoFrom(store);
287           context()->get_debug_info_mgr()->AnalyzeDebugInst(&*ii);
288           modified = true;
289         } break;
290         default:
291           break;
292       }
293     }
294 
295     while (!dead_instructions.empty()) {
296       Instruction* inst = dead_instructions.back();
297       dead_instructions.pop_back();
298       DCEInst(inst, [&dead_instructions](Instruction* other_inst) {
299         auto i = std::find(dead_instructions.begin(), dead_instructions.end(),
300                            other_inst);
301         if (i != dead_instructions.end()) {
302           dead_instructions.erase(i);
303         }
304       });
305     }
306   }
307   return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
308 }
309 
Initialize()310 void LocalAccessChainConvertPass::Initialize() {
311   // Initialize Target Variable Caches
312   seen_target_vars_.clear();
313   seen_non_target_vars_.clear();
314 
315   // Initialize collections
316   supported_ref_ptrs_.clear();
317 
318   // Initialize extension allowlist
319   InitExtensions();
320 }
321 
AllExtensionsSupported() const322 bool LocalAccessChainConvertPass::AllExtensionsSupported() const {
323   // This capability can now exist without the extension, so we have to check
324   // for the capability.  This pass is only looking at function scope symbols,
325   // so we do not care if there are variable pointers on storage buffers.
326   if (context()->get_feature_mgr()->HasCapability(
327           SpvCapabilityVariablePointers))
328     return false;
329   // If any extension not in allowlist, return false
330   for (auto& ei : get_module()->extensions()) {
331     const char* extName =
332         reinterpret_cast<const char*>(&ei.GetInOperand(0).words[0]);
333     if (extensions_allowlist_.find(extName) == extensions_allowlist_.end())
334       return false;
335   }
336   return true;
337 }
338 
ProcessImpl()339 Pass::Status LocalAccessChainConvertPass::ProcessImpl() {
340   // If non-32-bit integer type in module, terminate processing
341   // TODO(): Handle non-32-bit integer constants in access chains
342   for (const Instruction& inst : get_module()->types_values())
343     if (inst.opcode() == SpvOpTypeInt &&
344         inst.GetSingleWordInOperand(kTypeIntWidthInIdx) != 32)
345       return Status::SuccessWithoutChange;
346   // Do not process if module contains OpGroupDecorate. Additional
347   // support required in KillNamesAndDecorates().
348   // TODO(greg-lunarg): Add support for OpGroupDecorate
349   for (auto& ai : get_module()->annotations())
350     if (ai.opcode() == SpvOpGroupDecorate) return Status::SuccessWithoutChange;
351   // Do not process if any disallowed extensions are enabled
352   if (!AllExtensionsSupported()) return Status::SuccessWithoutChange;
353 
354   // Process all functions in the module.
355   Status status = Status::SuccessWithoutChange;
356   for (Function& func : *get_module()) {
357     status = CombineStatus(status, ConvertLocalAccessChains(&func));
358     if (status == Status::Failure) {
359       break;
360     }
361   }
362   return status;
363 }
364 
LocalAccessChainConvertPass()365 LocalAccessChainConvertPass::LocalAccessChainConvertPass() {}
366 
Process()367 Pass::Status LocalAccessChainConvertPass::Process() {
368   Initialize();
369   return ProcessImpl();
370 }
371 
InitExtensions()372 void LocalAccessChainConvertPass::InitExtensions() {
373   extensions_allowlist_.clear();
374   extensions_allowlist_.insert({
375       "SPV_AMD_shader_explicit_vertex_parameter",
376       "SPV_AMD_shader_trinary_minmax",
377       "SPV_AMD_gcn_shader",
378       "SPV_KHR_shader_ballot",
379       "SPV_AMD_shader_ballot",
380       "SPV_AMD_gpu_shader_half_float",
381       "SPV_KHR_shader_draw_parameters",
382       "SPV_KHR_subgroup_vote",
383       "SPV_KHR_8bit_storage",
384       "SPV_KHR_16bit_storage",
385       "SPV_KHR_device_group",
386       "SPV_KHR_multiview",
387       "SPV_NVX_multiview_per_view_attributes",
388       "SPV_NV_viewport_array2",
389       "SPV_NV_stereo_view_rendering",
390       "SPV_NV_sample_mask_override_coverage",
391       "SPV_NV_geometry_shader_passthrough",
392       "SPV_AMD_texture_gather_bias_lod",
393       "SPV_KHR_storage_buffer_storage_class",
394       // SPV_KHR_variable_pointers
395       //   Currently do not support extended pointer expressions
396       "SPV_AMD_gpu_shader_int16",
397       "SPV_KHR_post_depth_coverage",
398       "SPV_KHR_shader_atomic_counter_ops",
399       "SPV_EXT_shader_stencil_export",
400       "SPV_EXT_shader_viewport_index_layer",
401       "SPV_AMD_shader_image_load_store_lod",
402       "SPV_AMD_shader_fragment_mask",
403       "SPV_EXT_fragment_fully_covered",
404       "SPV_AMD_gpu_shader_half_float_fetch",
405       "SPV_GOOGLE_decorate_string",
406       "SPV_GOOGLE_hlsl_functionality1",
407       "SPV_GOOGLE_user_type",
408       "SPV_NV_shader_subgroup_partitioned",
409       "SPV_EXT_demote_to_helper_invocation",
410       "SPV_EXT_descriptor_indexing",
411       "SPV_NV_fragment_shader_barycentric",
412       "SPV_NV_compute_shader_derivatives",
413       "SPV_NV_shader_image_footprint",
414       "SPV_NV_shading_rate",
415       "SPV_NV_mesh_shader",
416       "SPV_NV_ray_tracing",
417       "SPV_KHR_ray_tracing",
418       "SPV_KHR_ray_query",
419       "SPV_EXT_fragment_invocation_density",
420       "SPV_KHR_terminate_invocation",
421   });
422 }
423 
424 }  // namespace opt
425 }  // namespace spvtools
426