1 // Copyright (c) 2017 The Khronos Group Inc.
2 // Copyright (c) 2017 Valve Corporation
3 // Copyright (c) 2017 LunarG Inc.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16
17 #include "source/opt/local_access_chain_convert_pass.h"
18
19 #include "ir_context.h"
20 #include "iterator.h"
21 #include "source/util/string_utils.h"
22
23 namespace spvtools {
24 namespace opt {
namespace {
// In-operand index of the object (value) operand of an OpStore.
constexpr uint32_t kStoreValIdInIdx{1};
// In-operand index of the base pointer operand of an OpAccessChain.
constexpr uint32_t kAccessChainPtrIdInIdx{0};
}  // namespace
29
BuildAndAppendInst(spv::Op opcode,uint32_t typeId,uint32_t resultId,const std::vector<Operand> & in_opnds,std::vector<std::unique_ptr<Instruction>> * newInsts)30 void LocalAccessChainConvertPass::BuildAndAppendInst(
31 spv::Op opcode, uint32_t typeId, uint32_t resultId,
32 const std::vector<Operand>& in_opnds,
33 std::vector<std::unique_ptr<Instruction>>* newInsts) {
34 std::unique_ptr<Instruction> newInst(
35 new Instruction(context(), opcode, typeId, resultId, in_opnds));
36 get_def_use_mgr()->AnalyzeInstDefUse(&*newInst);
37 newInsts->emplace_back(std::move(newInst));
38 }
39
BuildAndAppendVarLoad(const Instruction * ptrInst,uint32_t * varId,uint32_t * varPteTypeId,std::vector<std::unique_ptr<Instruction>> * newInsts)40 uint32_t LocalAccessChainConvertPass::BuildAndAppendVarLoad(
41 const Instruction* ptrInst, uint32_t* varId, uint32_t* varPteTypeId,
42 std::vector<std::unique_ptr<Instruction>>* newInsts) {
43 const uint32_t ldResultId = TakeNextId();
44 if (ldResultId == 0) {
45 return 0;
46 }
47
48 *varId = ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx);
49 const Instruction* varInst = get_def_use_mgr()->GetDef(*varId);
50 assert(varInst->opcode() == spv::Op::OpVariable);
51 *varPteTypeId = GetPointeeTypeId(varInst);
52 BuildAndAppendInst(spv::Op::OpLoad, *varPteTypeId, ldResultId,
53 {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {*varId}}},
54 newInsts);
55 return ldResultId;
56 }
57
AppendConstantOperands(const Instruction * ptrInst,std::vector<Operand> * in_opnds)58 void LocalAccessChainConvertPass::AppendConstantOperands(
59 const Instruction* ptrInst, std::vector<Operand>* in_opnds) {
60 uint32_t iidIdx = 0;
61 ptrInst->ForEachInId([&iidIdx, &in_opnds, this](const uint32_t* iid) {
62 if (iidIdx > 0) {
63 const Instruction* cInst = get_def_use_mgr()->GetDef(*iid);
64 const auto* constant_value =
65 context()->get_constant_mgr()->GetConstantFromInst(cInst);
66 assert(constant_value != nullptr &&
67 "Expecting the index to be a constant.");
68
69 // We take the sign extended value because OpAccessChain interprets the
70 // index as signed.
71 int64_t long_value = constant_value->GetSignExtendedValue();
72 assert(long_value <= UINT32_MAX && long_value >= 0 &&
73 "The index value is too large for a composite insert or extract "
74 "instruction.");
75
76 uint32_t val = static_cast<uint32_t>(long_value);
77 in_opnds->push_back(
78 {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {val}});
79 }
80 ++iidIdx;
81 });
82 }
83
// Removes the dependency of |original_load| on the access chain
// |address_inst| (through which it loads).
//
//  - If the access chain has no indices (its only in-operand is the base
//    pointer), every use of the access chain's result id is redirected to the
//    base pointer, and |original_load| is otherwise left alone.
//  - Otherwise, an OpLoad of the whole base variable is inserted immediately
//    before |original_load|. |original_load| is then rewritten in place into
//    an OpCompositeExtract of that loaded value, with the access chain's
//    constant indices as literal operands.
//
// Returns false only if no fresh result id could be reserved for the new load.
bool LocalAccessChainConvertPass::ReplaceAccessChainLoad(
    const Instruction* address_inst, Instruction* original_load) {
  if (address_inst->NumInOperands() == 1) {
    // An access chain with no indices is essentially a copy. All that is
    // needed is to propagate the address.
    context()->ReplaceAllUsesWith(
        address_inst->result_id(),
        address_inst->GetSingleWordInOperand(kAccessChainPtrIdInIdx));
    return true;
  }

  // Build a load of the whole base variable of |address_inst|. |varId| and
  // |varPteTypeId| are out-parameters of BuildAndAppendVarLoad that this
  // function does not need afterwards.
  std::vector<std::unique_ptr<Instruction>> new_inst;
  uint32_t varId;
  uint32_t varPteTypeId;
  const uint32_t ldResultId =
      BuildAndAppendVarLoad(address_inst, &varId, &varPteTypeId, &new_inst);
  if (ldResultId == 0) {
    return false;
  }

  // Give the new load the original load's debug info and RelaxedPrecision
  // decoration. This must happen before |new_inst| is moved from below.
  new_inst[0]->UpdateDebugInfoFrom(original_load);
  context()->get_decoration_mgr()->CloneDecorations(
      original_load->result_id(), ldResultId,
      {spv::Decoration::RelaxedPrecision});
  original_load->InsertBefore(std::move(new_inst));
  // The inserted load is now the node directly preceding |original_load|.
  context()->get_debug_info_mgr()->AnalyzeDebugInst(
      original_load->PreviousNode());

  // Rewrite |original_load| into an extract.
  Instruction::OperandList new_operands;

  // Keep the load's first two operands (its result type id and result id).
  new_operands.emplace_back(original_load->GetOperand(0));
  new_operands.emplace_back(original_load->GetOperand(1));

  // Composite operand: the whole-variable value loaded above.
  new_operands.emplace_back(
      Operand({spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}));
  // Literal indices taken from the access chain's constant index operands.
  AppendConstantOperands(address_inst, &new_operands);
  // Only the operands built here survive. The pointer operand and any
  // further operands of the original load are dropped.
  original_load->SetOpcode(spv::Op::OpCompositeExtract);
  original_load->ReplaceOperands(new_operands);
  context()->UpdateDefUse(original_load);
  return true;
}
128
GenAccessChainStoreReplacement(const Instruction * ptrInst,uint32_t valId,std::vector<std::unique_ptr<Instruction>> * newInsts)129 bool LocalAccessChainConvertPass::GenAccessChainStoreReplacement(
130 const Instruction* ptrInst, uint32_t valId,
131 std::vector<std::unique_ptr<Instruction>>* newInsts) {
132 if (ptrInst->NumInOperands() == 1) {
133 // An access chain with no indices is essentially a copy. However, we still
134 // have to create a new store because the old ones will be deleted.
135 BuildAndAppendInst(
136 spv::Op::OpStore, 0, 0,
137 {{spv_operand_type_t::SPV_OPERAND_TYPE_ID,
138 {ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx)}},
139 {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}}},
140 newInsts);
141 return true;
142 }
143
144 // Build and append load of variable in ptrInst
145 uint32_t varId;
146 uint32_t varPteTypeId;
147 const uint32_t ldResultId =
148 BuildAndAppendVarLoad(ptrInst, &varId, &varPteTypeId, newInsts);
149 if (ldResultId == 0) {
150 return false;
151 }
152
153 context()->get_decoration_mgr()->CloneDecorations(
154 varId, ldResultId, {spv::Decoration::RelaxedPrecision});
155
156 // Build and append Insert
157 const uint32_t insResultId = TakeNextId();
158 if (insResultId == 0) {
159 return false;
160 }
161 std::vector<Operand> ins_in_opnds = {
162 {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}},
163 {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}};
164 AppendConstantOperands(ptrInst, &ins_in_opnds);
165 BuildAndAppendInst(spv::Op::OpCompositeInsert, varPteTypeId, insResultId,
166 ins_in_opnds, newInsts);
167
168 context()->get_decoration_mgr()->CloneDecorations(
169 varId, insResultId, {spv::Decoration::RelaxedPrecision});
170
171 // Build and append Store
172 BuildAndAppendInst(spv::Op::OpStore, 0, 0,
173 {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {varId}},
174 {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {insResultId}}},
175 newInsts);
176 return true;
177 }
178
Is32BitConstantIndexAccessChain(const Instruction * acp) const179 bool LocalAccessChainConvertPass::Is32BitConstantIndexAccessChain(
180 const Instruction* acp) const {
181 uint32_t inIdx = 0;
182 return acp->WhileEachInId([&inIdx, this](const uint32_t* tid) {
183 if (inIdx > 0) {
184 Instruction* opInst = get_def_use_mgr()->GetDef(*tid);
185 if (opInst->opcode() != spv::Op::OpConstant) return false;
186 const auto* index =
187 context()->get_constant_mgr()->GetConstantFromInst(opInst);
188 int64_t index_value = index->GetSignExtendedValue();
189 if (index_value > UINT32_MAX) return false;
190 if (index_value < 0) return false;
191 }
192 ++inIdx;
193 return true;
194 });
195 }
196
HasOnlySupportedRefs(uint32_t ptrId)197 bool LocalAccessChainConvertPass::HasOnlySupportedRefs(uint32_t ptrId) {
198 if (supported_ref_ptrs_.find(ptrId) != supported_ref_ptrs_.end()) return true;
199 if (get_def_use_mgr()->WhileEachUser(ptrId, [this](Instruction* user) {
200 if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugValue ||
201 user->GetCommonDebugOpcode() == CommonDebugInfoDebugDeclare) {
202 return true;
203 }
204 spv::Op op = user->opcode();
205 if (IsNonPtrAccessChain(op) || op == spv::Op::OpCopyObject) {
206 if (!HasOnlySupportedRefs(user->result_id())) {
207 return false;
208 }
209 } else if (op != spv::Op::OpStore && op != spv::Op::OpLoad &&
210 op != spv::Op::OpName && !IsNonTypeDecorate(op)) {
211 return false;
212 }
213 return true;
214 })) {
215 supported_ref_ptrs_.insert(ptrId);
216 return true;
217 }
218 return false;
219 }
220
FindTargetVars(Function * func)221 void LocalAccessChainConvertPass::FindTargetVars(Function* func) {
222 for (auto bi = func->begin(); bi != func->end(); ++bi) {
223 for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
224 switch (ii->opcode()) {
225 case spv::Op::OpStore:
226 case spv::Op::OpLoad: {
227 uint32_t varId;
228 Instruction* ptrInst = GetPtr(&*ii, &varId);
229 if (!IsTargetVar(varId)) break;
230 const spv::Op op = ptrInst->opcode();
231 // Rule out variables with non-supported refs eg function calls
232 if (!HasOnlySupportedRefs(varId)) {
233 seen_non_target_vars_.insert(varId);
234 seen_target_vars_.erase(varId);
235 break;
236 }
237 // Rule out variables with nested access chains
238 // TODO(): Convert nested access chains
239 bool is_non_ptr_access_chain = IsNonPtrAccessChain(op);
240 if (is_non_ptr_access_chain && ptrInst->GetSingleWordInOperand(
241 kAccessChainPtrIdInIdx) != varId) {
242 seen_non_target_vars_.insert(varId);
243 seen_target_vars_.erase(varId);
244 break;
245 }
246 // Rule out variables accessed with non-constant indices
247 if (!Is32BitConstantIndexAccessChain(ptrInst)) {
248 seen_non_target_vars_.insert(varId);
249 seen_target_vars_.erase(varId);
250 break;
251 }
252
253 if (is_non_ptr_access_chain && AnyIndexIsOutOfBounds(ptrInst)) {
254 seen_non_target_vars_.insert(varId);
255 seen_target_vars_.erase(varId);
256 break;
257 }
258 } break;
259 default:
260 break;
261 }
262 }
263 }
264 }
265
// Rewrites every OpLoad and OpStore in |func| whose pointer is a non-pointer
// access chain rooted at a target variable (as selected by FindTargetVars()):
//  - loads become a whole-variable load plus an OpCompositeExtract
//    (ReplaceAccessChainLoad);
//  - stores become a load / OpCompositeInsert / store sequence
//    (GenAccessChainStoreReplacement), and the original store is deleted.
//
// Returns Status::Failure as soon as a replacement cannot be built (the
// helpers report this when no fresh result id is available),
// Status::SuccessWithChange if anything was rewritten, and
// Status::SuccessWithoutChange otherwise.
Pass::Status LocalAccessChainConvertPass::ConvertLocalAccessChains(
    Function* func) {
  FindTargetVars(func);
  // Replace access chains of all targeted variables with equivalent
  // extract and insert sequences
  bool modified = false;
  for (auto bi = func->begin(); bi != func->end(); ++bi) {
    // Original stores replaced in this block. They are killed only after the
    // block has been fully walked, so |ii| is never invalidated mid-loop.
    std::vector<Instruction*> dead_instructions;
    for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
      switch (ii->opcode()) {
        case spv::Op::OpLoad: {
          uint32_t varId;
          Instruction* ptrInst = GetPtr(&*ii, &varId);
          if (!IsNonPtrAccessChain(ptrInst->opcode())) break;
          if (!IsTargetVar(varId)) break;
          // The load is rewritten in place (new instructions go before it),
          // so |ii| remains valid.
          if (!ReplaceAccessChainLoad(ptrInst, &*ii)) {
            return Status::Failure;
          }
          modified = true;
        } break;
        case spv::Op::OpStore: {
          uint32_t varId;
          Instruction* store = &*ii;
          Instruction* ptrInst = GetPtr(store, &varId);
          if (!IsNonPtrAccessChain(ptrInst->opcode())) break;
          if (!IsTargetVar(varId)) break;
          std::vector<std::unique_ptr<Instruction>> newInsts;
          uint32_t valId = store->GetSingleWordInOperand(kStoreValIdInIdx);
          if (!GenAccessChainStoreReplacement(ptrInst, valId, &newInsts)) {
            return Status::Failure;
          }
          // GenAccessChainStoreReplacement always emits at least the final
          // OpStore, so this subtraction cannot underflow.
          size_t num_of_instructions_to_skip = newInsts.size() - 1;
          dead_instructions.push_back(store);
          // Insert the replacement sequence before the instruction that
          // follows |store|, i.e. immediately after |store|. |ii| then points
          // at the first inserted instruction.
          ++ii;
          ii = ii.InsertBefore(std::move(newInsts));
          // Give each inserted instruction the store's debug info. |ii| is
          // left on the last inserted one, so the loop's ++ii resumes right
          // after the replacement sequence.
          for (size_t i = 0; i < num_of_instructions_to_skip; ++i) {
            ii->UpdateDebugInfoFrom(store);
            context()->get_debug_info_mgr()->AnalyzeDebugInst(&*ii);
            ++ii;
          }
          ii->UpdateDebugInfoFrom(store);
          context()->get_debug_info_mgr()->AnalyzeDebugInst(&*ii);
          modified = true;
        } break;
        default:
          break;
      }
    }

    // Delete the replaced stores. DCEInst may also delete other instructions
    // that become dead as a result. The callback drops any such instruction
    // from |dead_instructions| so it is not killed a second time.
    while (!dead_instructions.empty()) {
      Instruction* inst = dead_instructions.back();
      dead_instructions.pop_back();
      DCEInst(inst, [&dead_instructions](Instruction* other_inst) {
        auto i = std::find(dead_instructions.begin(), dead_instructions.end(),
                           other_inst);
        if (i != dead_instructions.end()) {
          dead_instructions.erase(i);
        }
      });
    }
  }
  return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
}
329
Initialize()330 void LocalAccessChainConvertPass::Initialize() {
331 // Initialize Target Variable Caches
332 seen_target_vars_.clear();
333 seen_non_target_vars_.clear();
334
335 // Initialize collections
336 supported_ref_ptrs_.clear();
337
338 // Initialize extension allowlist
339 InitExtensions();
340 }
341
AllExtensionsSupported() const342 bool LocalAccessChainConvertPass::AllExtensionsSupported() const {
343 // This capability can now exist without the extension, so we have to check
344 // for the capability. This pass is only looking at function scope symbols,
345 // so we do not care if there are variable pointers on storage buffers.
346 if (context()->get_feature_mgr()->HasCapability(
347 spv::Capability::VariablePointers))
348 return false;
349 // If any extension not in allowlist, return false
350 for (auto& ei : get_module()->extensions()) {
351 const std::string extName = ei.GetInOperand(0).AsString();
352 if (extensions_allowlist_.find(extName) == extensions_allowlist_.end())
353 return false;
354 }
355 // only allow NonSemantic.Shader.DebugInfo.100, we cannot safely optimise
356 // around unknown extended
357 // instruction sets even if they are non-semantic
358 for (auto& inst : context()->module()->ext_inst_imports()) {
359 assert(inst.opcode() == spv::Op::OpExtInstImport &&
360 "Expecting an import of an extension's instruction set.");
361 const std::string extension_name = inst.GetInOperand(0).AsString();
362 if (spvtools::utils::starts_with(extension_name, "NonSemantic.") &&
363 extension_name != "NonSemantic.Shader.DebugInfo.100") {
364 return false;
365 }
366 }
367 return true;
368 }
369
ProcessImpl()370 Pass::Status LocalAccessChainConvertPass::ProcessImpl() {
371 // Do not process if module contains OpGroupDecorate. Additional
372 // support required in KillNamesAndDecorates().
373 // TODO(greg-lunarg): Add support for OpGroupDecorate
374 for (auto& ai : get_module()->annotations())
375 if (ai.opcode() == spv::Op::OpGroupDecorate)
376 return Status::SuccessWithoutChange;
377 // Do not process if any disallowed extensions are enabled
378 if (!AllExtensionsSupported()) return Status::SuccessWithoutChange;
379
380 // Process all functions in the module.
381 Status status = Status::SuccessWithoutChange;
382 for (Function& func : *get_module()) {
383 status = CombineStatus(status, ConvertLocalAccessChains(&func));
384 if (status == Status::Failure) {
385 break;
386 }
387 }
388 return status;
389 }
390
LocalAccessChainConvertPass()391 LocalAccessChainConvertPass::LocalAccessChainConvertPass() {}
392
Process()393 Pass::Status LocalAccessChainConvertPass::Process() {
394 Initialize();
395 return ProcessImpl();
396 }
397
InitExtensions()398 void LocalAccessChainConvertPass::InitExtensions() {
399 extensions_allowlist_.clear();
400 extensions_allowlist_.insert(
401 {"SPV_AMD_shader_explicit_vertex_parameter",
402 "SPV_AMD_shader_trinary_minmax", "SPV_AMD_gcn_shader",
403 "SPV_KHR_shader_ballot", "SPV_AMD_shader_ballot",
404 "SPV_AMD_gpu_shader_half_float", "SPV_KHR_shader_draw_parameters",
405 "SPV_KHR_subgroup_vote", "SPV_KHR_8bit_storage", "SPV_KHR_16bit_storage",
406 "SPV_KHR_device_group", "SPV_KHR_multiview",
407 "SPV_NVX_multiview_per_view_attributes", "SPV_NV_viewport_array2",
408 "SPV_NV_stereo_view_rendering", "SPV_NV_sample_mask_override_coverage",
409 "SPV_NV_geometry_shader_passthrough", "SPV_AMD_texture_gather_bias_lod",
410 "SPV_KHR_storage_buffer_storage_class",
411 // SPV_KHR_variable_pointers
412 // Currently do not support extended pointer expressions
413 "SPV_AMD_gpu_shader_int16", "SPV_KHR_post_depth_coverage",
414 "SPV_KHR_shader_atomic_counter_ops", "SPV_EXT_shader_stencil_export",
415 "SPV_EXT_shader_viewport_index_layer",
416 "SPV_AMD_shader_image_load_store_lod", "SPV_AMD_shader_fragment_mask",
417 "SPV_EXT_fragment_fully_covered", "SPV_AMD_gpu_shader_half_float_fetch",
418 "SPV_GOOGLE_decorate_string", "SPV_GOOGLE_hlsl_functionality1",
419 "SPV_GOOGLE_user_type", "SPV_NV_shader_subgroup_partitioned",
420 "SPV_EXT_demote_to_helper_invocation", "SPV_EXT_descriptor_indexing",
421 "SPV_NV_fragment_shader_barycentric",
422 "SPV_NV_compute_shader_derivatives", "SPV_NV_shader_image_footprint",
423 "SPV_NV_shading_rate", "SPV_NV_mesh_shader", "SPV_NV_ray_tracing",
424 "SPV_KHR_ray_tracing", "SPV_KHR_ray_query",
425 "SPV_EXT_fragment_invocation_density", "SPV_KHR_terminate_invocation",
426 "SPV_KHR_subgroup_uniform_control_flow", "SPV_KHR_integer_dot_product",
427 "SPV_EXT_shader_image_int64", "SPV_KHR_non_semantic_info",
428 "SPV_KHR_uniform_group_instructions",
429 "SPV_KHR_fragment_shader_barycentric", "SPV_KHR_vulkan_memory_model",
430 "SPV_NV_bindless_texture", "SPV_EXT_shader_atomic_float_add",
431 "SPV_EXT_fragment_shader_interlock",
432 "SPV_NV_compute_shader_derivatives"});
433 }
434
AnyIndexIsOutOfBounds(const Instruction * access_chain_inst)435 bool LocalAccessChainConvertPass::AnyIndexIsOutOfBounds(
436 const Instruction* access_chain_inst) {
437 assert(IsNonPtrAccessChain(access_chain_inst->opcode()));
438
439 analysis::TypeManager* type_mgr = context()->get_type_mgr();
440 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
441 auto constants = const_mgr->GetOperandConstants(access_chain_inst);
442 uint32_t base_pointer_id = access_chain_inst->GetSingleWordInOperand(0);
443 Instruction* base_pointer = get_def_use_mgr()->GetDef(base_pointer_id);
444 const analysis::Pointer* base_pointer_type =
445 type_mgr->GetType(base_pointer->type_id())->AsPointer();
446 assert(base_pointer_type != nullptr &&
447 "The base of the access chain is not a pointer.");
448 const analysis::Type* current_type = base_pointer_type->pointee_type();
449 for (uint32_t i = 1; i < access_chain_inst->NumInOperands(); ++i) {
450 if (IsIndexOutOfBounds(constants[i], current_type)) {
451 return true;
452 }
453
454 uint32_t index =
455 (constants[i]
456 ? static_cast<uint32_t>(constants[i]->GetZeroExtendedValue())
457 : 0);
458 current_type = type_mgr->GetMemberType(current_type, {index});
459 }
460
461 return false;
462 }
463
IsIndexOutOfBounds(const analysis::Constant * index,const analysis::Type * type) const464 bool LocalAccessChainConvertPass::IsIndexOutOfBounds(
465 const analysis::Constant* index, const analysis::Type* type) const {
466 if (index == nullptr) {
467 return false;
468 }
469 return index->GetZeroExtendedValue() >= type->NumberOfComponents();
470 }
471
472 } // namespace opt
473 } // namespace spvtools
474