1 // Copyright (c) 2017 The Khronos Group Inc.
2 // Copyright (c) 2017 Valve Corporation
3 // Copyright (c) 2017 LunarG Inc.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16
17 #include "source/opt/local_access_chain_convert_pass.h"
18
19 #include "ir_builder.h"
20 #include "ir_context.h"
21 #include "iterator.h"
22 #include "source/util/string_utils.h"
23
24 namespace spvtools {
25 namespace opt {
26
namespace {

// In-operand index of the value written by an OpStore (in-operand 0 is the
// pointer being stored through).
constexpr uint32_t kStoreValIdInIdx = 1;

// In-operand index of the base pointer of an access chain; the indices
// follow it.
constexpr uint32_t kAccessChainPtrIdInIdx = 0;

}  // anonymous namespace
33
BuildAndAppendInst(SpvOp opcode,uint32_t typeId,uint32_t resultId,const std::vector<Operand> & in_opnds,std::vector<std::unique_ptr<Instruction>> * newInsts)34 void LocalAccessChainConvertPass::BuildAndAppendInst(
35 SpvOp opcode, uint32_t typeId, uint32_t resultId,
36 const std::vector<Operand>& in_opnds,
37 std::vector<std::unique_ptr<Instruction>>* newInsts) {
38 std::unique_ptr<Instruction> newInst(
39 new Instruction(context(), opcode, typeId, resultId, in_opnds));
40 get_def_use_mgr()->AnalyzeInstDefUse(&*newInst);
41 newInsts->emplace_back(std::move(newInst));
42 }
43
BuildAndAppendVarLoad(const Instruction * ptrInst,uint32_t * varId,uint32_t * varPteTypeId,std::vector<std::unique_ptr<Instruction>> * newInsts)44 uint32_t LocalAccessChainConvertPass::BuildAndAppendVarLoad(
45 const Instruction* ptrInst, uint32_t* varId, uint32_t* varPteTypeId,
46 std::vector<std::unique_ptr<Instruction>>* newInsts) {
47 const uint32_t ldResultId = TakeNextId();
48 if (ldResultId == 0) {
49 return 0;
50 }
51
52 *varId = ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx);
53 const Instruction* varInst = get_def_use_mgr()->GetDef(*varId);
54 assert(varInst->opcode() == SpvOpVariable);
55 *varPteTypeId = GetPointeeTypeId(varInst);
56 BuildAndAppendInst(SpvOpLoad, *varPteTypeId, ldResultId,
57 {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {*varId}}},
58 newInsts);
59 return ldResultId;
60 }
61
AppendConstantOperands(const Instruction * ptrInst,std::vector<Operand> * in_opnds)62 void LocalAccessChainConvertPass::AppendConstantOperands(
63 const Instruction* ptrInst, std::vector<Operand>* in_opnds) {
64 uint32_t iidIdx = 0;
65 ptrInst->ForEachInId([&iidIdx, &in_opnds, this](const uint32_t* iid) {
66 if (iidIdx > 0) {
67 const Instruction* cInst = get_def_use_mgr()->GetDef(*iid);
68 const auto* constant_value =
69 context()->get_constant_mgr()->GetConstantFromInst(cInst);
70 assert(constant_value != nullptr &&
71 "Expecting the index to be a constant.");
72
73 // We take the sign extended value because OpAccessChain interprets the
74 // index as signed.
75 int64_t long_value = constant_value->GetSignExtendedValue();
76 assert(long_value <= UINT32_MAX && long_value >= 0 &&
77 "The index value is too large for a composite insert or extract "
78 "instruction.");
79
80 uint32_t val = static_cast<uint32_t>(long_value);
81 in_opnds->push_back(
82 {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {val}});
83 }
84 ++iidIdx;
85 });
86 }
87
ReplaceAccessChainLoad(const Instruction * address_inst,Instruction * original_load)88 bool LocalAccessChainConvertPass::ReplaceAccessChainLoad(
89 const Instruction* address_inst, Instruction* original_load) {
90 // Build and append load of variable in ptrInst
91 if (address_inst->NumInOperands() == 1) {
92 // An access chain with no indices is essentially a copy. All that is
93 // needed is to propagate the address.
94 context()->ReplaceAllUsesWith(
95 address_inst->result_id(),
96 address_inst->GetSingleWordInOperand(kAccessChainPtrIdInIdx));
97 return true;
98 }
99
100 std::vector<std::unique_ptr<Instruction>> new_inst;
101 uint32_t varId;
102 uint32_t varPteTypeId;
103 const uint32_t ldResultId =
104 BuildAndAppendVarLoad(address_inst, &varId, &varPteTypeId, &new_inst);
105 if (ldResultId == 0) {
106 return false;
107 }
108
109 new_inst[0]->UpdateDebugInfoFrom(original_load);
110 context()->get_decoration_mgr()->CloneDecorations(
111 original_load->result_id(), ldResultId, {SpvDecorationRelaxedPrecision});
112 original_load->InsertBefore(std::move(new_inst));
113 context()->get_debug_info_mgr()->AnalyzeDebugInst(
114 original_load->PreviousNode());
115
116 // Rewrite |original_load| into an extract.
117 Instruction::OperandList new_operands;
118
119 // copy the result id and the type id to the new operand list.
120 new_operands.emplace_back(original_load->GetOperand(0));
121 new_operands.emplace_back(original_load->GetOperand(1));
122
123 new_operands.emplace_back(
124 Operand({spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}));
125 AppendConstantOperands(address_inst, &new_operands);
126 original_load->SetOpcode(SpvOpCompositeExtract);
127 original_load->ReplaceOperands(new_operands);
128 context()->UpdateDefUse(original_load);
129 return true;
130 }
131
GenAccessChainStoreReplacement(const Instruction * ptrInst,uint32_t valId,std::vector<std::unique_ptr<Instruction>> * newInsts)132 bool LocalAccessChainConvertPass::GenAccessChainStoreReplacement(
133 const Instruction* ptrInst, uint32_t valId,
134 std::vector<std::unique_ptr<Instruction>>* newInsts) {
135 if (ptrInst->NumInOperands() == 1) {
136 // An access chain with no indices is essentially a copy. However, we still
137 // have to create a new store because the old ones will be deleted.
138 BuildAndAppendInst(
139 SpvOpStore, 0, 0,
140 {{spv_operand_type_t::SPV_OPERAND_TYPE_ID,
141 {ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx)}},
142 {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}}},
143 newInsts);
144 return true;
145 }
146
147 // Build and append load of variable in ptrInst
148 uint32_t varId;
149 uint32_t varPteTypeId;
150 const uint32_t ldResultId =
151 BuildAndAppendVarLoad(ptrInst, &varId, &varPteTypeId, newInsts);
152 if (ldResultId == 0) {
153 return false;
154 }
155
156 context()->get_decoration_mgr()->CloneDecorations(
157 varId, ldResultId, {SpvDecorationRelaxedPrecision});
158
159 // Build and append Insert
160 const uint32_t insResultId = TakeNextId();
161 if (insResultId == 0) {
162 return false;
163 }
164 std::vector<Operand> ins_in_opnds = {
165 {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}},
166 {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}};
167 AppendConstantOperands(ptrInst, &ins_in_opnds);
168 BuildAndAppendInst(SpvOpCompositeInsert, varPteTypeId, insResultId,
169 ins_in_opnds, newInsts);
170
171 context()->get_decoration_mgr()->CloneDecorations(
172 varId, insResultId, {SpvDecorationRelaxedPrecision});
173
174 // Build and append Store
175 BuildAndAppendInst(SpvOpStore, 0, 0,
176 {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {varId}},
177 {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {insResultId}}},
178 newInsts);
179 return true;
180 }
181
Is32BitConstantIndexAccessChain(const Instruction * acp) const182 bool LocalAccessChainConvertPass::Is32BitConstantIndexAccessChain(
183 const Instruction* acp) const {
184 uint32_t inIdx = 0;
185 return acp->WhileEachInId([&inIdx, this](const uint32_t* tid) {
186 if (inIdx > 0) {
187 Instruction* opInst = get_def_use_mgr()->GetDef(*tid);
188 if (opInst->opcode() != SpvOpConstant) return false;
189 const auto* index =
190 context()->get_constant_mgr()->GetConstantFromInst(opInst);
191 int64_t index_value = index->GetSignExtendedValue();
192 if (index_value > UINT32_MAX) return false;
193 if (index_value < 0) return false;
194 }
195 ++inIdx;
196 return true;
197 });
198 }
199
HasOnlySupportedRefs(uint32_t ptrId)200 bool LocalAccessChainConvertPass::HasOnlySupportedRefs(uint32_t ptrId) {
201 if (supported_ref_ptrs_.find(ptrId) != supported_ref_ptrs_.end()) return true;
202 if (get_def_use_mgr()->WhileEachUser(ptrId, [this](Instruction* user) {
203 if (user->GetCommonDebugOpcode() == CommonDebugInfoDebugValue ||
204 user->GetCommonDebugOpcode() == CommonDebugInfoDebugDeclare) {
205 return true;
206 }
207 SpvOp op = user->opcode();
208 if (IsNonPtrAccessChain(op) || op == SpvOpCopyObject) {
209 if (!HasOnlySupportedRefs(user->result_id())) {
210 return false;
211 }
212 } else if (op != SpvOpStore && op != SpvOpLoad && op != SpvOpName &&
213 !IsNonTypeDecorate(op)) {
214 return false;
215 }
216 return true;
217 })) {
218 supported_ref_ptrs_.insert(ptrId);
219 return true;
220 }
221 return false;
222 }
223
FindTargetVars(Function * func)224 void LocalAccessChainConvertPass::FindTargetVars(Function* func) {
225 for (auto bi = func->begin(); bi != func->end(); ++bi) {
226 for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
227 switch (ii->opcode()) {
228 case SpvOpStore:
229 case SpvOpLoad: {
230 uint32_t varId;
231 Instruction* ptrInst = GetPtr(&*ii, &varId);
232 if (!IsTargetVar(varId)) break;
233 const SpvOp op = ptrInst->opcode();
234 // Rule out variables with non-supported refs eg function calls
235 if (!HasOnlySupportedRefs(varId)) {
236 seen_non_target_vars_.insert(varId);
237 seen_target_vars_.erase(varId);
238 break;
239 }
240 // Rule out variables with nested access chains
241 // TODO(): Convert nested access chains
242 bool is_non_ptr_access_chain = IsNonPtrAccessChain(op);
243 if (is_non_ptr_access_chain && ptrInst->GetSingleWordInOperand(
244 kAccessChainPtrIdInIdx) != varId) {
245 seen_non_target_vars_.insert(varId);
246 seen_target_vars_.erase(varId);
247 break;
248 }
249 // Rule out variables accessed with non-constant indices
250 if (!Is32BitConstantIndexAccessChain(ptrInst)) {
251 seen_non_target_vars_.insert(varId);
252 seen_target_vars_.erase(varId);
253 break;
254 }
255
256 if (is_non_ptr_access_chain && AnyIndexIsOutOfBounds(ptrInst)) {
257 seen_non_target_vars_.insert(varId);
258 seen_target_vars_.erase(varId);
259 break;
260 }
261 } break;
262 default:
263 break;
264 }
265 }
266 }
267 }
268
ConvertLocalAccessChains(Function * func)269 Pass::Status LocalAccessChainConvertPass::ConvertLocalAccessChains(
270 Function* func) {
271 FindTargetVars(func);
272 // Replace access chains of all targeted variables with equivalent
273 // extract and insert sequences
274 bool modified = false;
275 for (auto bi = func->begin(); bi != func->end(); ++bi) {
276 std::vector<Instruction*> dead_instructions;
277 for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
278 switch (ii->opcode()) {
279 case SpvOpLoad: {
280 uint32_t varId;
281 Instruction* ptrInst = GetPtr(&*ii, &varId);
282 if (!IsNonPtrAccessChain(ptrInst->opcode())) break;
283 if (!IsTargetVar(varId)) break;
284 if (!ReplaceAccessChainLoad(ptrInst, &*ii)) {
285 return Status::Failure;
286 }
287 modified = true;
288 } break;
289 case SpvOpStore: {
290 uint32_t varId;
291 Instruction* store = &*ii;
292 Instruction* ptrInst = GetPtr(store, &varId);
293 if (!IsNonPtrAccessChain(ptrInst->opcode())) break;
294 if (!IsTargetVar(varId)) break;
295 std::vector<std::unique_ptr<Instruction>> newInsts;
296 uint32_t valId = store->GetSingleWordInOperand(kStoreValIdInIdx);
297 if (!GenAccessChainStoreReplacement(ptrInst, valId, &newInsts)) {
298 return Status::Failure;
299 }
300 size_t num_of_instructions_to_skip = newInsts.size() - 1;
301 dead_instructions.push_back(store);
302 ++ii;
303 ii = ii.InsertBefore(std::move(newInsts));
304 for (size_t i = 0; i < num_of_instructions_to_skip; ++i) {
305 ii->UpdateDebugInfoFrom(store);
306 context()->get_debug_info_mgr()->AnalyzeDebugInst(&*ii);
307 ++ii;
308 }
309 ii->UpdateDebugInfoFrom(store);
310 context()->get_debug_info_mgr()->AnalyzeDebugInst(&*ii);
311 modified = true;
312 } break;
313 default:
314 break;
315 }
316 }
317
318 while (!dead_instructions.empty()) {
319 Instruction* inst = dead_instructions.back();
320 dead_instructions.pop_back();
321 DCEInst(inst, [&dead_instructions](Instruction* other_inst) {
322 auto i = std::find(dead_instructions.begin(), dead_instructions.end(),
323 other_inst);
324 if (i != dead_instructions.end()) {
325 dead_instructions.erase(i);
326 }
327 });
328 }
329 }
330 return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
331 }
332
Initialize()333 void LocalAccessChainConvertPass::Initialize() {
334 // Initialize Target Variable Caches
335 seen_target_vars_.clear();
336 seen_non_target_vars_.clear();
337
338 // Initialize collections
339 supported_ref_ptrs_.clear();
340
341 // Initialize extension allowlist
342 InitExtensions();
343 }
344
AllExtensionsSupported() const345 bool LocalAccessChainConvertPass::AllExtensionsSupported() const {
346 // This capability can now exist without the extension, so we have to check
347 // for the capability. This pass is only looking at function scope symbols,
348 // so we do not care if there are variable pointers on storage buffers.
349 if (context()->get_feature_mgr()->HasCapability(
350 SpvCapabilityVariablePointers))
351 return false;
352 // If any extension not in allowlist, return false
353 for (auto& ei : get_module()->extensions()) {
354 const std::string extName = ei.GetInOperand(0).AsString();
355 if (extensions_allowlist_.find(extName) == extensions_allowlist_.end())
356 return false;
357 }
358 // only allow NonSemantic.Shader.DebugInfo.100, we cannot safely optimise
359 // around unknown extended
360 // instruction sets even if they are non-semantic
361 for (auto& inst : context()->module()->ext_inst_imports()) {
362 assert(inst.opcode() == SpvOpExtInstImport &&
363 "Expecting an import of an extension's instruction set.");
364 const std::string extension_name = inst.GetInOperand(0).AsString();
365 if (spvtools::utils::starts_with(extension_name, "NonSemantic.") &&
366 extension_name != "NonSemantic.Shader.DebugInfo.100") {
367 return false;
368 }
369 }
370 return true;
371 }
372
ProcessImpl()373 Pass::Status LocalAccessChainConvertPass::ProcessImpl() {
374 // Do not process if module contains OpGroupDecorate. Additional
375 // support required in KillNamesAndDecorates().
376 // TODO(greg-lunarg): Add support for OpGroupDecorate
377 for (auto& ai : get_module()->annotations())
378 if (ai.opcode() == SpvOpGroupDecorate) return Status::SuccessWithoutChange;
379 // Do not process if any disallowed extensions are enabled
380 if (!AllExtensionsSupported()) return Status::SuccessWithoutChange;
381
382 // Process all functions in the module.
383 Status status = Status::SuccessWithoutChange;
384 for (Function& func : *get_module()) {
385 status = CombineStatus(status, ConvertLocalAccessChains(&func));
386 if (status == Status::Failure) {
387 break;
388 }
389 }
390 return status;
391 }
392
LocalAccessChainConvertPass()393 LocalAccessChainConvertPass::LocalAccessChainConvertPass() {}
394
Process()395 Pass::Status LocalAccessChainConvertPass::Process() {
396 Initialize();
397 return ProcessImpl();
398 }
399
InitExtensions()400 void LocalAccessChainConvertPass::InitExtensions() {
401 extensions_allowlist_.clear();
402 extensions_allowlist_.insert({
403 "SPV_AMD_shader_explicit_vertex_parameter",
404 "SPV_AMD_shader_trinary_minmax",
405 "SPV_AMD_gcn_shader",
406 "SPV_KHR_shader_ballot",
407 "SPV_AMD_shader_ballot",
408 "SPV_AMD_gpu_shader_half_float",
409 "SPV_KHR_shader_draw_parameters",
410 "SPV_KHR_subgroup_vote",
411 "SPV_KHR_8bit_storage",
412 "SPV_KHR_16bit_storage",
413 "SPV_KHR_device_group",
414 "SPV_KHR_multiview",
415 "SPV_NVX_multiview_per_view_attributes",
416 "SPV_NV_viewport_array2",
417 "SPV_NV_stereo_view_rendering",
418 "SPV_NV_sample_mask_override_coverage",
419 "SPV_NV_geometry_shader_passthrough",
420 "SPV_AMD_texture_gather_bias_lod",
421 "SPV_KHR_storage_buffer_storage_class",
422 // SPV_KHR_variable_pointers
423 // Currently do not support extended pointer expressions
424 "SPV_AMD_gpu_shader_int16",
425 "SPV_KHR_post_depth_coverage",
426 "SPV_KHR_shader_atomic_counter_ops",
427 "SPV_EXT_shader_stencil_export",
428 "SPV_EXT_shader_viewport_index_layer",
429 "SPV_AMD_shader_image_load_store_lod",
430 "SPV_AMD_shader_fragment_mask",
431 "SPV_EXT_fragment_fully_covered",
432 "SPV_AMD_gpu_shader_half_float_fetch",
433 "SPV_GOOGLE_decorate_string",
434 "SPV_GOOGLE_hlsl_functionality1",
435 "SPV_GOOGLE_user_type",
436 "SPV_NV_shader_subgroup_partitioned",
437 "SPV_EXT_demote_to_helper_invocation",
438 "SPV_EXT_descriptor_indexing",
439 "SPV_NV_fragment_shader_barycentric",
440 "SPV_NV_compute_shader_derivatives",
441 "SPV_NV_shader_image_footprint",
442 "SPV_NV_shading_rate",
443 "SPV_NV_mesh_shader",
444 "SPV_NV_ray_tracing",
445 "SPV_KHR_ray_tracing",
446 "SPV_KHR_ray_query",
447 "SPV_EXT_fragment_invocation_density",
448 "SPV_KHR_terminate_invocation",
449 "SPV_KHR_subgroup_uniform_control_flow",
450 "SPV_KHR_integer_dot_product",
451 "SPV_EXT_shader_image_int64",
452 "SPV_KHR_non_semantic_info",
453 "SPV_KHR_uniform_group_instructions",
454 "SPV_KHR_fragment_shader_barycentric",
455 });
456 }
457
AnyIndexIsOutOfBounds(const Instruction * access_chain_inst)458 bool LocalAccessChainConvertPass::AnyIndexIsOutOfBounds(
459 const Instruction* access_chain_inst) {
460 assert(IsNonPtrAccessChain(access_chain_inst->opcode()));
461
462 analysis::TypeManager* type_mgr = context()->get_type_mgr();
463 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
464 auto constants = const_mgr->GetOperandConstants(access_chain_inst);
465 uint32_t base_pointer_id = access_chain_inst->GetSingleWordInOperand(0);
466 Instruction* base_pointer = get_def_use_mgr()->GetDef(base_pointer_id);
467 const analysis::Pointer* base_pointer_type =
468 type_mgr->GetType(base_pointer->type_id())->AsPointer();
469 assert(base_pointer_type != nullptr &&
470 "The base of the access chain is not a pointer.");
471 const analysis::Type* current_type = base_pointer_type->pointee_type();
472 for (uint32_t i = 1; i < access_chain_inst->NumInOperands(); ++i) {
473 if (IsIndexOutOfBounds(constants[i], current_type)) {
474 return true;
475 }
476
477 uint32_t index =
478 (constants[i]
479 ? static_cast<uint32_t>(constants[i]->GetZeroExtendedValue())
480 : 0);
481 current_type = type_mgr->GetMemberType(current_type, {index});
482 }
483
484 return false;
485 }
486
IsIndexOutOfBounds(const analysis::Constant * index,const analysis::Type * type) const487 bool LocalAccessChainConvertPass::IsIndexOutOfBounds(
488 const analysis::Constant* index, const analysis::Type* type) const {
489 if (index == nullptr) {
490 return false;
491 }
492 return index->GetZeroExtendedValue() >= type->NumberOfComponents();
493 }
494
495 } // namespace opt
496 } // namespace spvtools
497