• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2017 The Khronos Group Inc.
2 // Copyright (c) 2017 Valve Corporation
3 // Copyright (c) 2017 LunarG Inc.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 
17 #include "source/opt/local_access_chain_convert_pass.h"
18 
19 #include "ir_builder.h"
20 #include "ir_context.h"
21 #include "iterator.h"
22 #include "source/util/string_utils.h"
23 
24 namespace spvtools {
25 namespace opt {
26 
namespace {

// In-operand index of the stored object in an OpStore
// (in-operand 0 is the pointer being written through).
constexpr uint32_t kStoreValIdInIdx = 1;

// In-operand index of the base pointer of an access chain; the indices
// follow it.
constexpr uint32_t kAccessChainPtrIdInIdx = 0;

}  // anonymous namespace
33 
BuildAndAppendInst(SpvOp opcode,uint32_t typeId,uint32_t resultId,const std::vector<Operand> & in_opnds,std::vector<std::unique_ptr<Instruction>> * newInsts)34 void LocalAccessChainConvertPass::BuildAndAppendInst(
35     SpvOp opcode, uint32_t typeId, uint32_t resultId,
36     const std::vector<Operand>& in_opnds,
37     std::vector<std::unique_ptr<Instruction>>* newInsts) {
38   std::unique_ptr<Instruction> newInst(
39       new Instruction(context(), opcode, typeId, resultId, in_opnds));
40   get_def_use_mgr()->AnalyzeInstDefUse(&*newInst);
41   newInsts->emplace_back(std::move(newInst));
42 }
43 
BuildAndAppendVarLoad(const Instruction * ptrInst,uint32_t * varId,uint32_t * varPteTypeId,std::vector<std::unique_ptr<Instruction>> * newInsts)44 uint32_t LocalAccessChainConvertPass::BuildAndAppendVarLoad(
45     const Instruction* ptrInst, uint32_t* varId, uint32_t* varPteTypeId,
46     std::vector<std::unique_ptr<Instruction>>* newInsts) {
47   const uint32_t ldResultId = TakeNextId();
48   if (ldResultId == 0) {
49     return 0;
50   }
51 
52   *varId = ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx);
53   const Instruction* varInst = get_def_use_mgr()->GetDef(*varId);
54   assert(varInst->opcode() == SpvOpVariable);
55   *varPteTypeId = GetPointeeTypeId(varInst);
56   BuildAndAppendInst(SpvOpLoad, *varPteTypeId, ldResultId,
57                      {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {*varId}}},
58                      newInsts);
59   return ldResultId;
60 }
61 
AppendConstantOperands(const Instruction * ptrInst,std::vector<Operand> * in_opnds)62 void LocalAccessChainConvertPass::AppendConstantOperands(
63     const Instruction* ptrInst, std::vector<Operand>* in_opnds) {
64   uint32_t iidIdx = 0;
65   ptrInst->ForEachInId([&iidIdx, &in_opnds, this](const uint32_t* iid) {
66     if (iidIdx > 0) {
67       const Instruction* cInst = get_def_use_mgr()->GetDef(*iid);
68       const auto* constant_value =
69           context()->get_constant_mgr()->GetConstantFromInst(cInst);
70       assert(constant_value != nullptr &&
71              "Expecting the index to be a constant.");
72 
73       // We take the sign extended value because OpAccessChain interprets the
74       // index as signed.
75       int64_t long_value = constant_value->GetSignExtendedValue();
76       assert(long_value <= UINT32_MAX && long_value >= 0 &&
77              "The index value is too large for a composite insert or extract "
78              "instruction.");
79 
80       uint32_t val = static_cast<uint32_t>(long_value);
81       in_opnds->push_back(
82           {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {val}});
83     }
84     ++iidIdx;
85   });
86 }
87 
ReplaceAccessChainLoad(const Instruction * address_inst,Instruction * original_load)88 bool LocalAccessChainConvertPass::ReplaceAccessChainLoad(
89     const Instruction* address_inst, Instruction* original_load) {
90   // Build and append load of variable in ptrInst
91   if (address_inst->NumInOperands() == 1) {
92     // An access chain with no indices is essentially a copy.  All that is
93     // needed is to propagate the address.
94     context()->ReplaceAllUsesWith(
95         address_inst->result_id(),
96         address_inst->GetSingleWordInOperand(kAccessChainPtrIdInIdx));
97     return true;
98   }
99 
100   std::vector<std::unique_ptr<Instruction>> new_inst;
101   uint32_t varId;
102   uint32_t varPteTypeId;
103   const uint32_t ldResultId =
104       BuildAndAppendVarLoad(address_inst, &varId, &varPteTypeId, &new_inst);
105   if (ldResultId == 0) {
106     return false;
107   }
108 
109   new_inst[0]->UpdateDebugInfoFrom(original_load);
110   context()->get_decoration_mgr()->CloneDecorations(
111       original_load->result_id(), ldResultId, {SpvDecorationRelaxedPrecision});
112   original_load->InsertBefore(std::move(new_inst));
113   context()->get_debug_info_mgr()->AnalyzeDebugInst(
114       original_load->PreviousNode());
115 
116   // Rewrite |original_load| into an extract.
117   Instruction::OperandList new_operands;
118 
119   // copy the result id and the type id to the new operand list.
120   new_operands.emplace_back(original_load->GetOperand(0));
121   new_operands.emplace_back(original_load->GetOperand(1));
122 
123   new_operands.emplace_back(
124       Operand({spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}));
125   AppendConstantOperands(address_inst, &new_operands);
126   original_load->SetOpcode(SpvOpCompositeExtract);
127   original_load->ReplaceOperands(new_operands);
128   context()->UpdateDefUse(original_load);
129   return true;
130 }
131 
GenAccessChainStoreReplacement(const Instruction * ptrInst,uint32_t valId,std::vector<std::unique_ptr<Instruction>> * newInsts)132 bool LocalAccessChainConvertPass::GenAccessChainStoreReplacement(
133     const Instruction* ptrInst, uint32_t valId,
134     std::vector<std::unique_ptr<Instruction>>* newInsts) {
135   if (ptrInst->NumInOperands() == 1) {
136     // An access chain with no indices is essentially a copy.  However, we still
137     // have to create a new store because the old ones will be deleted.
138     BuildAndAppendInst(
139         SpvOpStore, 0, 0,
140         {{spv_operand_type_t::SPV_OPERAND_TYPE_ID,
141           {ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx)}},
142          {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}}},
143         newInsts);
144     return true;
145   }
146 
147   // Build and append load of variable in ptrInst
148   uint32_t varId;
149   uint32_t varPteTypeId;
150   const uint32_t ldResultId =
151       BuildAndAppendVarLoad(ptrInst, &varId, &varPteTypeId, newInsts);
152   if (ldResultId == 0) {
153     return false;
154   }
155 
156   context()->get_decoration_mgr()->CloneDecorations(
157       varId, ldResultId, {SpvDecorationRelaxedPrecision});
158 
159   // Build and append Insert
160   const uint32_t insResultId = TakeNextId();
161   if (insResultId == 0) {
162     return false;
163   }
164   std::vector<Operand> ins_in_opnds = {
165       {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}},
166       {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}};
167   AppendConstantOperands(ptrInst, &ins_in_opnds);
168   BuildAndAppendInst(SpvOpCompositeInsert, varPteTypeId, insResultId,
169                      ins_in_opnds, newInsts);
170 
171   context()->get_decoration_mgr()->CloneDecorations(
172       varId, insResultId, {SpvDecorationRelaxedPrecision});
173 
174   // Build and append Store
175   BuildAndAppendInst(SpvOpStore, 0, 0,
176                      {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {varId}},
177                       {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {insResultId}}},
178                      newInsts);
179   return true;
180 }
181 
Is32BitConstantIndexAccessChain(const Instruction * acp) const182 bool LocalAccessChainConvertPass::Is32BitConstantIndexAccessChain(
183     const Instruction* acp) const {
184   uint32_t inIdx = 0;
185   return acp->WhileEachInId([&inIdx, this](const uint32_t* tid) {
186     if (inIdx > 0) {
187       Instruction* opInst = get_def_use_mgr()->GetDef(*tid);
188       if (opInst->opcode() != SpvOpConstant) return false;
189       const auto* index =
190           context()->get_constant_mgr()->GetConstantFromInst(opInst);
191       int64_t index_value = index->GetSignExtendedValue();
192       if (index_value > UINT32_MAX) return false;
193       if (index_value < 0) return false;
194     }
195     ++inIdx;
196     return true;
197   });
198 }
199 
HasOnlySupportedRefs(uint32_t ptrId)200 bool LocalAccessChainConvertPass::HasOnlySupportedRefs(uint32_t ptrId) {
201   if (supported_ref_ptrs_.find(ptrId) != supported_ref_ptrs_.end()) return true;
202   if (get_def_use_mgr()->WhileEachUser(ptrId, [this](Instruction* user) {
203         if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugValue ||
204             user->GetCommonDebugOpcode() == CommonDebugInfoDebugDeclare) {
205           return true;
206         }
207         SpvOp op = user->opcode();
208         if (IsNonPtrAccessChain(op) || op == SpvOpCopyObject) {
209           if (!HasOnlySupportedRefs(user->result_id())) {
210             return false;
211           }
212         } else if (op != SpvOpStore && op != SpvOpLoad && op != SpvOpName &&
213                    !IsNonTypeDecorate(op)) {
214           return false;
215         }
216         return true;
217       })) {
218     supported_ref_ptrs_.insert(ptrId);
219     return true;
220   }
221   return false;
222 }
223 
FindTargetVars(Function * func)224 void LocalAccessChainConvertPass::FindTargetVars(Function* func) {
225   for (auto bi = func->begin(); bi != func->end(); ++bi) {
226     for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
227       switch (ii->opcode()) {
228         case SpvOpStore:
229         case SpvOpLoad: {
230           uint32_t varId;
231           Instruction* ptrInst = GetPtr(&*ii, &varId);
232           if (!IsTargetVar(varId)) break;
233           const SpvOp op = ptrInst->opcode();
234           // Rule out variables with non-supported refs eg function calls
235           if (!HasOnlySupportedRefs(varId)) {
236             seen_non_target_vars_.insert(varId);
237             seen_target_vars_.erase(varId);
238             break;
239           }
240           // Rule out variables with nested access chains
241           // TODO(): Convert nested access chains
242           bool is_non_ptr_access_chain = IsNonPtrAccessChain(op);
243           if (is_non_ptr_access_chain && ptrInst->GetSingleWordInOperand(
244                                              kAccessChainPtrIdInIdx) != varId) {
245             seen_non_target_vars_.insert(varId);
246             seen_target_vars_.erase(varId);
247             break;
248           }
249           // Rule out variables accessed with non-constant indices
250           if (!Is32BitConstantIndexAccessChain(ptrInst)) {
251             seen_non_target_vars_.insert(varId);
252             seen_target_vars_.erase(varId);
253             break;
254           }
255 
256           if (is_non_ptr_access_chain && AnyIndexIsOutOfBounds(ptrInst)) {
257             seen_non_target_vars_.insert(varId);
258             seen_target_vars_.erase(varId);
259             break;
260           }
261         } break;
262         default:
263           break;
264       }
265     }
266   }
267 }
268 
ConvertLocalAccessChains(Function * func)269 Pass::Status LocalAccessChainConvertPass::ConvertLocalAccessChains(
270     Function* func) {
271   FindTargetVars(func);
272   // Replace access chains of all targeted variables with equivalent
273   // extract and insert sequences
274   bool modified = false;
275   for (auto bi = func->begin(); bi != func->end(); ++bi) {
276     std::vector<Instruction*> dead_instructions;
277     for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
278       switch (ii->opcode()) {
279         case SpvOpLoad: {
280           uint32_t varId;
281           Instruction* ptrInst = GetPtr(&*ii, &varId);
282           if (!IsNonPtrAccessChain(ptrInst->opcode())) break;
283           if (!IsTargetVar(varId)) break;
284           if (!ReplaceAccessChainLoad(ptrInst, &*ii)) {
285             return Status::Failure;
286           }
287           modified = true;
288         } break;
289         case SpvOpStore: {
290           uint32_t varId;
291           Instruction* store = &*ii;
292           Instruction* ptrInst = GetPtr(store, &varId);
293           if (!IsNonPtrAccessChain(ptrInst->opcode())) break;
294           if (!IsTargetVar(varId)) break;
295           std::vector<std::unique_ptr<Instruction>> newInsts;
296           uint32_t valId = store->GetSingleWordInOperand(kStoreValIdInIdx);
297           if (!GenAccessChainStoreReplacement(ptrInst, valId, &newInsts)) {
298             return Status::Failure;
299           }
300           size_t num_of_instructions_to_skip = newInsts.size() - 1;
301           dead_instructions.push_back(store);
302           ++ii;
303           ii = ii.InsertBefore(std::move(newInsts));
304           for (size_t i = 0; i < num_of_instructions_to_skip; ++i) {
305             ii->UpdateDebugInfoFrom(store);
306             context()->get_debug_info_mgr()->AnalyzeDebugInst(&*ii);
307             ++ii;
308           }
309           ii->UpdateDebugInfoFrom(store);
310           context()->get_debug_info_mgr()->AnalyzeDebugInst(&*ii);
311           modified = true;
312         } break;
313         default:
314           break;
315       }
316     }
317 
318     while (!dead_instructions.empty()) {
319       Instruction* inst = dead_instructions.back();
320       dead_instructions.pop_back();
321       DCEInst(inst, [&dead_instructions](Instruction* other_inst) {
322         auto i = std::find(dead_instructions.begin(), dead_instructions.end(),
323                            other_inst);
324         if (i != dead_instructions.end()) {
325           dead_instructions.erase(i);
326         }
327       });
328     }
329   }
330   return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
331 }
332 
Initialize()333 void LocalAccessChainConvertPass::Initialize() {
334   // Initialize Target Variable Caches
335   seen_target_vars_.clear();
336   seen_non_target_vars_.clear();
337 
338   // Initialize collections
339   supported_ref_ptrs_.clear();
340 
341   // Initialize extension allowlist
342   InitExtensions();
343 }
344 
AllExtensionsSupported() const345 bool LocalAccessChainConvertPass::AllExtensionsSupported() const {
346   // This capability can now exist without the extension, so we have to check
347   // for the capability.  This pass is only looking at function scope symbols,
348   // so we do not care if there are variable pointers on storage buffers.
349   if (context()->get_feature_mgr()->HasCapability(
350           SpvCapabilityVariablePointers))
351     return false;
352   // If any extension not in allowlist, return false
353   for (auto& ei : get_module()->extensions()) {
354     const std::string extName = ei.GetInOperand(0).AsString();
355     if (extensions_allowlist_.find(extName) == extensions_allowlist_.end())
356       return false;
357   }
358   // only allow NonSemantic.Shader.DebugInfo.100, we cannot safely optimise
359   // around unknown extended
360   // instruction sets even if they are non-semantic
361   for (auto& inst : context()->module()->ext_inst_imports()) {
362     assert(inst.opcode() == SpvOpExtInstImport &&
363            "Expecting an import of an extension's instruction set.");
364     const std::string extension_name = inst.GetInOperand(0).AsString();
365     if (spvtools::utils::starts_with(extension_name, "NonSemantic.") &&
366         extension_name != "NonSemantic.Shader.DebugInfo.100") {
367       return false;
368     }
369   }
370   return true;
371 }
372 
ProcessImpl()373 Pass::Status LocalAccessChainConvertPass::ProcessImpl() {
374   // Do not process if module contains OpGroupDecorate. Additional
375   // support required in KillNamesAndDecorates().
376   // TODO(greg-lunarg): Add support for OpGroupDecorate
377   for (auto& ai : get_module()->annotations())
378     if (ai.opcode() == SpvOpGroupDecorate) return Status::SuccessWithoutChange;
379   // Do not process if any disallowed extensions are enabled
380   if (!AllExtensionsSupported()) return Status::SuccessWithoutChange;
381 
382   // Process all functions in the module.
383   Status status = Status::SuccessWithoutChange;
384   for (Function& func : *get_module()) {
385     status = CombineStatus(status, ConvertLocalAccessChains(&func));
386     if (status == Status::Failure) {
387       break;
388     }
389   }
390   return status;
391 }
392 
LocalAccessChainConvertPass()393 LocalAccessChainConvertPass::LocalAccessChainConvertPass() {}
394 
Process()395 Pass::Status LocalAccessChainConvertPass::Process() {
396   Initialize();
397   return ProcessImpl();
398 }
399 
InitExtensions()400 void LocalAccessChainConvertPass::InitExtensions() {
401   extensions_allowlist_.clear();
402   extensions_allowlist_.insert({
403       "SPV_AMD_shader_explicit_vertex_parameter",
404       "SPV_AMD_shader_trinary_minmax",
405       "SPV_AMD_gcn_shader",
406       "SPV_KHR_shader_ballot",
407       "SPV_AMD_shader_ballot",
408       "SPV_AMD_gpu_shader_half_float",
409       "SPV_KHR_shader_draw_parameters",
410       "SPV_KHR_subgroup_vote",
411       "SPV_KHR_8bit_storage",
412       "SPV_KHR_16bit_storage",
413       "SPV_KHR_device_group",
414       "SPV_KHR_multiview",
415       "SPV_NVX_multiview_per_view_attributes",
416       "SPV_NV_viewport_array2",
417       "SPV_NV_stereo_view_rendering",
418       "SPV_NV_sample_mask_override_coverage",
419       "SPV_NV_geometry_shader_passthrough",
420       "SPV_AMD_texture_gather_bias_lod",
421       "SPV_KHR_storage_buffer_storage_class",
422       // SPV_KHR_variable_pointers
423       //   Currently do not support extended pointer expressions
424       "SPV_AMD_gpu_shader_int16",
425       "SPV_KHR_post_depth_coverage",
426       "SPV_KHR_shader_atomic_counter_ops",
427       "SPV_EXT_shader_stencil_export",
428       "SPV_EXT_shader_viewport_index_layer",
429       "SPV_AMD_shader_image_load_store_lod",
430       "SPV_AMD_shader_fragment_mask",
431       "SPV_EXT_fragment_fully_covered",
432       "SPV_AMD_gpu_shader_half_float_fetch",
433       "SPV_GOOGLE_decorate_string",
434       "SPV_GOOGLE_hlsl_functionality1",
435       "SPV_GOOGLE_user_type",
436       "SPV_NV_shader_subgroup_partitioned",
437       "SPV_EXT_demote_to_helper_invocation",
438       "SPV_EXT_descriptor_indexing",
439       "SPV_NV_fragment_shader_barycentric",
440       "SPV_NV_compute_shader_derivatives",
441       "SPV_NV_shader_image_footprint",
442       "SPV_NV_shading_rate",
443       "SPV_NV_mesh_shader",
444       "SPV_NV_ray_tracing",
445       "SPV_KHR_ray_tracing",
446       "SPV_KHR_ray_query",
447       "SPV_EXT_fragment_invocation_density",
448       "SPV_KHR_terminate_invocation",
449       "SPV_KHR_subgroup_uniform_control_flow",
450       "SPV_KHR_integer_dot_product",
451       "SPV_EXT_shader_image_int64",
452       "SPV_KHR_non_semantic_info",
453       "SPV_KHR_uniform_group_instructions",
454       "SPV_KHR_fragment_shader_barycentric",
455   });
456 }
457 
AnyIndexIsOutOfBounds(const Instruction * access_chain_inst)458 bool LocalAccessChainConvertPass::AnyIndexIsOutOfBounds(
459     const Instruction* access_chain_inst) {
460   assert(IsNonPtrAccessChain(access_chain_inst->opcode()));
461 
462   analysis::TypeManager* type_mgr = context()->get_type_mgr();
463   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
464   auto constants = const_mgr->GetOperandConstants(access_chain_inst);
465   uint32_t base_pointer_id = access_chain_inst->GetSingleWordInOperand(0);
466   Instruction* base_pointer = get_def_use_mgr()->GetDef(base_pointer_id);
467   const analysis::Pointer* base_pointer_type =
468       type_mgr->GetType(base_pointer->type_id())->AsPointer();
469   assert(base_pointer_type != nullptr &&
470          "The base of the access chain is not a pointer.");
471   const analysis::Type* current_type = base_pointer_type->pointee_type();
472   for (uint32_t i = 1; i < access_chain_inst->NumInOperands(); ++i) {
473     if (IsIndexOutOfBounds(constants[i], current_type)) {
474       return true;
475     }
476 
477     uint32_t index =
478         (constants[i]
479              ? static_cast<uint32_t>(constants[i]->GetZeroExtendedValue())
480              : 0);
481     current_type = type_mgr->GetMemberType(current_type, {index});
482   }
483 
484   return false;
485 }
486 
IsIndexOutOfBounds(const analysis::Constant * index,const analysis::Type * type) const487 bool LocalAccessChainConvertPass::IsIndexOutOfBounds(
488     const analysis::Constant* index, const analysis::Type* type) const {
489   if (index == nullptr) {
490     return false;
491   }
492   return index->GetZeroExtendedValue() >= type->NumberOfComponents();
493 }
494 
495 }  // namespace opt
496 }  // namespace spvtools
497