• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2017 The Khronos Group Inc.
2 // Copyright (c) 2017 Valve Corporation
3 // Copyright (c) 2017 LunarG Inc.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 
17 #include "source/opt/local_access_chain_convert_pass.h"
18 
19 #include "ir_context.h"
20 #include "iterator.h"
21 #include "source/util/string_utils.h"
22 
23 namespace spvtools {
24 namespace opt {
namespace {
// In-operand position of the stored value in an OpStore.
constexpr uint32_t kStoreValIdInIdx{1};
// In-operand position of the base pointer in an (InBounds)AccessChain.
constexpr uint32_t kAccessChainPtrIdInIdx{0};
}  // namespace
29 
BuildAndAppendInst(spv::Op opcode,uint32_t typeId,uint32_t resultId,const std::vector<Operand> & in_opnds,std::vector<std::unique_ptr<Instruction>> * newInsts)30 void LocalAccessChainConvertPass::BuildAndAppendInst(
31     spv::Op opcode, uint32_t typeId, uint32_t resultId,
32     const std::vector<Operand>& in_opnds,
33     std::vector<std::unique_ptr<Instruction>>* newInsts) {
34   std::unique_ptr<Instruction> newInst(
35       new Instruction(context(), opcode, typeId, resultId, in_opnds));
36   get_def_use_mgr()->AnalyzeInstDefUse(&*newInst);
37   newInsts->emplace_back(std::move(newInst));
38 }
39 
BuildAndAppendVarLoad(const Instruction * ptrInst,uint32_t * varId,uint32_t * varPteTypeId,std::vector<std::unique_ptr<Instruction>> * newInsts)40 uint32_t LocalAccessChainConvertPass::BuildAndAppendVarLoad(
41     const Instruction* ptrInst, uint32_t* varId, uint32_t* varPteTypeId,
42     std::vector<std::unique_ptr<Instruction>>* newInsts) {
43   const uint32_t ldResultId = TakeNextId();
44   if (ldResultId == 0) {
45     return 0;
46   }
47 
48   *varId = ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx);
49   const Instruction* varInst = get_def_use_mgr()->GetDef(*varId);
50   assert(varInst->opcode() == spv::Op::OpVariable);
51   *varPteTypeId = GetPointeeTypeId(varInst);
52   BuildAndAppendInst(spv::Op::OpLoad, *varPteTypeId, ldResultId,
53                      {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {*varId}}},
54                      newInsts);
55   return ldResultId;
56 }
57 
AppendConstantOperands(const Instruction * ptrInst,std::vector<Operand> * in_opnds)58 void LocalAccessChainConvertPass::AppendConstantOperands(
59     const Instruction* ptrInst, std::vector<Operand>* in_opnds) {
60   uint32_t iidIdx = 0;
61   ptrInst->ForEachInId([&iidIdx, &in_opnds, this](const uint32_t* iid) {
62     if (iidIdx > 0) {
63       const Instruction* cInst = get_def_use_mgr()->GetDef(*iid);
64       const auto* constant_value =
65           context()->get_constant_mgr()->GetConstantFromInst(cInst);
66       assert(constant_value != nullptr &&
67              "Expecting the index to be a constant.");
68 
69       // We take the sign extended value because OpAccessChain interprets the
70       // index as signed.
71       int64_t long_value = constant_value->GetSignExtendedValue();
72       assert(long_value <= UINT32_MAX && long_value >= 0 &&
73              "The index value is too large for a composite insert or extract "
74              "instruction.");
75 
76       uint32_t val = static_cast<uint32_t>(long_value);
77       in_opnds->push_back(
78           {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {val}});
79     }
80     ++iidIdx;
81   });
82 }
83 
ReplaceAccessChainLoad(const Instruction * address_inst,Instruction * original_load)84 bool LocalAccessChainConvertPass::ReplaceAccessChainLoad(
85     const Instruction* address_inst, Instruction* original_load) {
86   // Build and append load of variable in ptrInst
87   if (address_inst->NumInOperands() == 1) {
88     // An access chain with no indices is essentially a copy.  All that is
89     // needed is to propagate the address.
90     context()->ReplaceAllUsesWith(
91         address_inst->result_id(),
92         address_inst->GetSingleWordInOperand(kAccessChainPtrIdInIdx));
93     return true;
94   }
95 
96   std::vector<std::unique_ptr<Instruction>> new_inst;
97   uint32_t varId;
98   uint32_t varPteTypeId;
99   const uint32_t ldResultId =
100       BuildAndAppendVarLoad(address_inst, &varId, &varPteTypeId, &new_inst);
101   if (ldResultId == 0) {
102     return false;
103   }
104 
105   new_inst[0]->UpdateDebugInfoFrom(original_load);
106   context()->get_decoration_mgr()->CloneDecorations(
107       original_load->result_id(), ldResultId,
108       {spv::Decoration::RelaxedPrecision});
109   original_load->InsertBefore(std::move(new_inst));
110   context()->get_debug_info_mgr()->AnalyzeDebugInst(
111       original_load->PreviousNode());
112 
113   // Rewrite |original_load| into an extract.
114   Instruction::OperandList new_operands;
115 
116   // copy the result id and the type id to the new operand list.
117   new_operands.emplace_back(original_load->GetOperand(0));
118   new_operands.emplace_back(original_load->GetOperand(1));
119 
120   new_operands.emplace_back(
121       Operand({spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}));
122   AppendConstantOperands(address_inst, &new_operands);
123   original_load->SetOpcode(spv::Op::OpCompositeExtract);
124   original_load->ReplaceOperands(new_operands);
125   context()->UpdateDefUse(original_load);
126   return true;
127 }
128 
GenAccessChainStoreReplacement(const Instruction * ptrInst,uint32_t valId,std::vector<std::unique_ptr<Instruction>> * newInsts)129 bool LocalAccessChainConvertPass::GenAccessChainStoreReplacement(
130     const Instruction* ptrInst, uint32_t valId,
131     std::vector<std::unique_ptr<Instruction>>* newInsts) {
132   if (ptrInst->NumInOperands() == 1) {
133     // An access chain with no indices is essentially a copy.  However, we still
134     // have to create a new store because the old ones will be deleted.
135     BuildAndAppendInst(
136         spv::Op::OpStore, 0, 0,
137         {{spv_operand_type_t::SPV_OPERAND_TYPE_ID,
138           {ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx)}},
139          {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}}},
140         newInsts);
141     return true;
142   }
143 
144   // Build and append load of variable in ptrInst
145   uint32_t varId;
146   uint32_t varPteTypeId;
147   const uint32_t ldResultId =
148       BuildAndAppendVarLoad(ptrInst, &varId, &varPteTypeId, newInsts);
149   if (ldResultId == 0) {
150     return false;
151   }
152 
153   context()->get_decoration_mgr()->CloneDecorations(
154       varId, ldResultId, {spv::Decoration::RelaxedPrecision});
155 
156   // Build and append Insert
157   const uint32_t insResultId = TakeNextId();
158   if (insResultId == 0) {
159     return false;
160   }
161   std::vector<Operand> ins_in_opnds = {
162       {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}},
163       {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}};
164   AppendConstantOperands(ptrInst, &ins_in_opnds);
165   BuildAndAppendInst(spv::Op::OpCompositeInsert, varPteTypeId, insResultId,
166                      ins_in_opnds, newInsts);
167 
168   context()->get_decoration_mgr()->CloneDecorations(
169       varId, insResultId, {spv::Decoration::RelaxedPrecision});
170 
171   // Build and append Store
172   BuildAndAppendInst(spv::Op::OpStore, 0, 0,
173                      {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {varId}},
174                       {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {insResultId}}},
175                      newInsts);
176   return true;
177 }
178 
Is32BitConstantIndexAccessChain(const Instruction * acp) const179 bool LocalAccessChainConvertPass::Is32BitConstantIndexAccessChain(
180     const Instruction* acp) const {
181   uint32_t inIdx = 0;
182   return acp->WhileEachInId([&inIdx, this](const uint32_t* tid) {
183     if (inIdx > 0) {
184       Instruction* opInst = get_def_use_mgr()->GetDef(*tid);
185       if (opInst->opcode() != spv::Op::OpConstant) return false;
186       const auto* index =
187           context()->get_constant_mgr()->GetConstantFromInst(opInst);
188       int64_t index_value = index->GetSignExtendedValue();
189       if (index_value > UINT32_MAX) return false;
190       if (index_value < 0) return false;
191     }
192     ++inIdx;
193     return true;
194   });
195 }
196 
HasOnlySupportedRefs(uint32_t ptrId)197 bool LocalAccessChainConvertPass::HasOnlySupportedRefs(uint32_t ptrId) {
198   if (supported_ref_ptrs_.find(ptrId) != supported_ref_ptrs_.end()) return true;
199   if (get_def_use_mgr()->WhileEachUser(ptrId, [this](Instruction* user) {
200         if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugValue ||
201             user->GetCommonDebugOpcode() == CommonDebugInfoDebugDeclare) {
202           return true;
203         }
204         spv::Op op = user->opcode();
205         if (IsNonPtrAccessChain(op) || op == spv::Op::OpCopyObject) {
206           if (!HasOnlySupportedRefs(user->result_id())) {
207             return false;
208           }
209         } else if (op != spv::Op::OpStore && op != spv::Op::OpLoad &&
210                    op != spv::Op::OpName && !IsNonTypeDecorate(op)) {
211           return false;
212         }
213         return true;
214       })) {
215     supported_ref_ptrs_.insert(ptrId);
216     return true;
217   }
218   return false;
219 }
220 
FindTargetVars(Function * func)221 void LocalAccessChainConvertPass::FindTargetVars(Function* func) {
222   for (auto bi = func->begin(); bi != func->end(); ++bi) {
223     for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
224       switch (ii->opcode()) {
225         case spv::Op::OpStore:
226         case spv::Op::OpLoad: {
227           uint32_t varId;
228           Instruction* ptrInst = GetPtr(&*ii, &varId);
229           if (!IsTargetVar(varId)) break;
230           const spv::Op op = ptrInst->opcode();
231           // Rule out variables with non-supported refs eg function calls
232           if (!HasOnlySupportedRefs(varId)) {
233             seen_non_target_vars_.insert(varId);
234             seen_target_vars_.erase(varId);
235             break;
236           }
237           // Rule out variables with nested access chains
238           // TODO(): Convert nested access chains
239           bool is_non_ptr_access_chain = IsNonPtrAccessChain(op);
240           if (is_non_ptr_access_chain && ptrInst->GetSingleWordInOperand(
241                                              kAccessChainPtrIdInIdx) != varId) {
242             seen_non_target_vars_.insert(varId);
243             seen_target_vars_.erase(varId);
244             break;
245           }
246           // Rule out variables accessed with non-constant indices
247           if (!Is32BitConstantIndexAccessChain(ptrInst)) {
248             seen_non_target_vars_.insert(varId);
249             seen_target_vars_.erase(varId);
250             break;
251           }
252 
253           if (is_non_ptr_access_chain && AnyIndexIsOutOfBounds(ptrInst)) {
254             seen_non_target_vars_.insert(varId);
255             seen_target_vars_.erase(varId);
256             break;
257           }
258         } break;
259         default:
260           break;
261       }
262     }
263   }
264 }
265 
ConvertLocalAccessChains(Function * func)266 Pass::Status LocalAccessChainConvertPass::ConvertLocalAccessChains(
267     Function* func) {
268   FindTargetVars(func);
269   // Replace access chains of all targeted variables with equivalent
270   // extract and insert sequences
271   bool modified = false;
272   for (auto bi = func->begin(); bi != func->end(); ++bi) {
273     std::vector<Instruction*> dead_instructions;
274     for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
275       switch (ii->opcode()) {
276         case spv::Op::OpLoad: {
277           uint32_t varId;
278           Instruction* ptrInst = GetPtr(&*ii, &varId);
279           if (!IsNonPtrAccessChain(ptrInst->opcode())) break;
280           if (!IsTargetVar(varId)) break;
281           if (!ReplaceAccessChainLoad(ptrInst, &*ii)) {
282             return Status::Failure;
283           }
284           modified = true;
285         } break;
286         case spv::Op::OpStore: {
287           uint32_t varId;
288           Instruction* store = &*ii;
289           Instruction* ptrInst = GetPtr(store, &varId);
290           if (!IsNonPtrAccessChain(ptrInst->opcode())) break;
291           if (!IsTargetVar(varId)) break;
292           std::vector<std::unique_ptr<Instruction>> newInsts;
293           uint32_t valId = store->GetSingleWordInOperand(kStoreValIdInIdx);
294           if (!GenAccessChainStoreReplacement(ptrInst, valId, &newInsts)) {
295             return Status::Failure;
296           }
297           size_t num_of_instructions_to_skip = newInsts.size() - 1;
298           dead_instructions.push_back(store);
299           ++ii;
300           ii = ii.InsertBefore(std::move(newInsts));
301           for (size_t i = 0; i < num_of_instructions_to_skip; ++i) {
302             ii->UpdateDebugInfoFrom(store);
303             context()->get_debug_info_mgr()->AnalyzeDebugInst(&*ii);
304             ++ii;
305           }
306           ii->UpdateDebugInfoFrom(store);
307           context()->get_debug_info_mgr()->AnalyzeDebugInst(&*ii);
308           modified = true;
309         } break;
310         default:
311           break;
312       }
313     }
314 
315     while (!dead_instructions.empty()) {
316       Instruction* inst = dead_instructions.back();
317       dead_instructions.pop_back();
318       DCEInst(inst, [&dead_instructions](Instruction* other_inst) {
319         auto i = std::find(dead_instructions.begin(), dead_instructions.end(),
320                            other_inst);
321         if (i != dead_instructions.end()) {
322           dead_instructions.erase(i);
323         }
324       });
325     }
326   }
327   return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
328 }
329 
Initialize()330 void LocalAccessChainConvertPass::Initialize() {
331   // Initialize Target Variable Caches
332   seen_target_vars_.clear();
333   seen_non_target_vars_.clear();
334 
335   // Initialize collections
336   supported_ref_ptrs_.clear();
337 
338   // Initialize extension allowlist
339   InitExtensions();
340 }
341 
AllExtensionsSupported() const342 bool LocalAccessChainConvertPass::AllExtensionsSupported() const {
343   // This capability can now exist without the extension, so we have to check
344   // for the capability.  This pass is only looking at function scope symbols,
345   // so we do not care if there are variable pointers on storage buffers.
346   if (context()->get_feature_mgr()->HasCapability(
347           spv::Capability::VariablePointers))
348     return false;
349   // If any extension not in allowlist, return false
350   for (auto& ei : get_module()->extensions()) {
351     const std::string extName = ei.GetInOperand(0).AsString();
352     if (extensions_allowlist_.find(extName) == extensions_allowlist_.end())
353       return false;
354   }
355   // only allow NonSemantic.Shader.DebugInfo.100, we cannot safely optimise
356   // around unknown extended
357   // instruction sets even if they are non-semantic
358   for (auto& inst : context()->module()->ext_inst_imports()) {
359     assert(inst.opcode() == spv::Op::OpExtInstImport &&
360            "Expecting an import of an extension's instruction set.");
361     const std::string extension_name = inst.GetInOperand(0).AsString();
362     if (spvtools::utils::starts_with(extension_name, "NonSemantic.") &&
363         extension_name != "NonSemantic.Shader.DebugInfo.100") {
364       return false;
365     }
366   }
367   return true;
368 }
369 
ProcessImpl()370 Pass::Status LocalAccessChainConvertPass::ProcessImpl() {
371   // Do not process if module contains OpGroupDecorate. Additional
372   // support required in KillNamesAndDecorates().
373   // TODO(greg-lunarg): Add support for OpGroupDecorate
374   for (auto& ai : get_module()->annotations())
375     if (ai.opcode() == spv::Op::OpGroupDecorate)
376       return Status::SuccessWithoutChange;
377   // Do not process if any disallowed extensions are enabled
378   if (!AllExtensionsSupported()) return Status::SuccessWithoutChange;
379 
380   // Process all functions in the module.
381   Status status = Status::SuccessWithoutChange;
382   for (Function& func : *get_module()) {
383     status = CombineStatus(status, ConvertLocalAccessChains(&func));
384     if (status == Status::Failure) {
385       break;
386     }
387   }
388   return status;
389 }
390 
LocalAccessChainConvertPass()391 LocalAccessChainConvertPass::LocalAccessChainConvertPass() {}
392 
Process()393 Pass::Status LocalAccessChainConvertPass::Process() {
394   Initialize();
395   return ProcessImpl();
396 }
397 
InitExtensions()398 void LocalAccessChainConvertPass::InitExtensions() {
399   extensions_allowlist_.clear();
400   extensions_allowlist_.insert(
401       {"SPV_AMD_shader_explicit_vertex_parameter",
402        "SPV_AMD_shader_trinary_minmax", "SPV_AMD_gcn_shader",
403        "SPV_KHR_shader_ballot", "SPV_AMD_shader_ballot",
404        "SPV_AMD_gpu_shader_half_float", "SPV_KHR_shader_draw_parameters",
405        "SPV_KHR_subgroup_vote", "SPV_KHR_8bit_storage", "SPV_KHR_16bit_storage",
406        "SPV_KHR_device_group", "SPV_KHR_multiview",
407        "SPV_NVX_multiview_per_view_attributes", "SPV_NV_viewport_array2",
408        "SPV_NV_stereo_view_rendering", "SPV_NV_sample_mask_override_coverage",
409        "SPV_NV_geometry_shader_passthrough", "SPV_AMD_texture_gather_bias_lod",
410        "SPV_KHR_storage_buffer_storage_class",
411        // SPV_KHR_variable_pointers
412        //   Currently do not support extended pointer expressions
413        "SPV_AMD_gpu_shader_int16", "SPV_KHR_post_depth_coverage",
414        "SPV_KHR_shader_atomic_counter_ops", "SPV_EXT_shader_stencil_export",
415        "SPV_EXT_shader_viewport_index_layer",
416        "SPV_AMD_shader_image_load_store_lod", "SPV_AMD_shader_fragment_mask",
417        "SPV_EXT_fragment_fully_covered", "SPV_AMD_gpu_shader_half_float_fetch",
418        "SPV_GOOGLE_decorate_string", "SPV_GOOGLE_hlsl_functionality1",
419        "SPV_GOOGLE_user_type", "SPV_NV_shader_subgroup_partitioned",
420        "SPV_EXT_demote_to_helper_invocation", "SPV_EXT_descriptor_indexing",
421        "SPV_NV_fragment_shader_barycentric",
422        "SPV_NV_compute_shader_derivatives", "SPV_NV_shader_image_footprint",
423        "SPV_NV_shading_rate", "SPV_NV_mesh_shader", "SPV_NV_ray_tracing",
424        "SPV_KHR_ray_tracing", "SPV_KHR_ray_query",
425        "SPV_EXT_fragment_invocation_density", "SPV_KHR_terminate_invocation",
426        "SPV_KHR_subgroup_uniform_control_flow", "SPV_KHR_integer_dot_product",
427        "SPV_EXT_shader_image_int64", "SPV_KHR_non_semantic_info",
428        "SPV_KHR_uniform_group_instructions",
429        "SPV_KHR_fragment_shader_barycentric", "SPV_KHR_vulkan_memory_model",
430        "SPV_NV_bindless_texture", "SPV_EXT_shader_atomic_float_add",
431        "SPV_EXT_fragment_shader_interlock",
432        "SPV_NV_compute_shader_derivatives"});
433 }
434 
AnyIndexIsOutOfBounds(const Instruction * access_chain_inst)435 bool LocalAccessChainConvertPass::AnyIndexIsOutOfBounds(
436     const Instruction* access_chain_inst) {
437   assert(IsNonPtrAccessChain(access_chain_inst->opcode()));
438 
439   analysis::TypeManager* type_mgr = context()->get_type_mgr();
440   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
441   auto constants = const_mgr->GetOperandConstants(access_chain_inst);
442   uint32_t base_pointer_id = access_chain_inst->GetSingleWordInOperand(0);
443   Instruction* base_pointer = get_def_use_mgr()->GetDef(base_pointer_id);
444   const analysis::Pointer* base_pointer_type =
445       type_mgr->GetType(base_pointer->type_id())->AsPointer();
446   assert(base_pointer_type != nullptr &&
447          "The base of the access chain is not a pointer.");
448   const analysis::Type* current_type = base_pointer_type->pointee_type();
449   for (uint32_t i = 1; i < access_chain_inst->NumInOperands(); ++i) {
450     if (IsIndexOutOfBounds(constants[i], current_type)) {
451       return true;
452     }
453 
454     uint32_t index =
455         (constants[i]
456              ? static_cast<uint32_t>(constants[i]->GetZeroExtendedValue())
457              : 0);
458     current_type = type_mgr->GetMemberType(current_type, {index});
459   }
460 
461   return false;
462 }
463 
IsIndexOutOfBounds(const analysis::Constant * index,const analysis::Type * type) const464 bool LocalAccessChainConvertPass::IsIndexOutOfBounds(
465     const analysis::Constant* index, const analysis::Type* type) const {
466   if (index == nullptr) {
467     return false;
468   }
469   return index->GetZeroExtendedValue() >= type->NumberOfComponents();
470 }
471 
472 }  // namespace opt
473 }  // namespace spvtools
474