1 // Copyright (c) 2017 The Khronos Group Inc.
2 // Copyright (c) 2017 Valve Corporation
3 // Copyright (c) 2017 LunarG Inc.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16
17 #include "source/opt/local_access_chain_convert_pass.h"
18
19 #include "ir_builder.h"
20 #include "ir_context.h"
21 #include "iterator.h"
22
23 namespace spvtools {
24 namespace opt {
25
namespace {

// In-operand index of the stored value id on a store instruction.
constexpr uint32_t kStoreValIdInIdx = 1;
// In-operand index of the base pointer id on an access chain instruction.
constexpr uint32_t kAccessChainPtrIdInIdx = 0;
// In-operand index of the value word read from a constant used as an index.
constexpr uint32_t kConstantValueInIdx = 0;
// In-operand index of the bit width on an OpTypeInt instruction.
constexpr uint32_t kTypeIntWidthInIdx = 0;

}  // anonymous namespace
34
BuildAndAppendInst(SpvOp opcode,uint32_t typeId,uint32_t resultId,const std::vector<Operand> & in_opnds,std::vector<std::unique_ptr<Instruction>> * newInsts)35 void LocalAccessChainConvertPass::BuildAndAppendInst(
36 SpvOp opcode, uint32_t typeId, uint32_t resultId,
37 const std::vector<Operand>& in_opnds,
38 std::vector<std::unique_ptr<Instruction>>* newInsts) {
39 std::unique_ptr<Instruction> newInst(
40 new Instruction(context(), opcode, typeId, resultId, in_opnds));
41 get_def_use_mgr()->AnalyzeInstDefUse(&*newInst);
42 newInsts->emplace_back(std::move(newInst));
43 }
44
BuildAndAppendVarLoad(const Instruction * ptrInst,uint32_t * varId,uint32_t * varPteTypeId,std::vector<std::unique_ptr<Instruction>> * newInsts)45 uint32_t LocalAccessChainConvertPass::BuildAndAppendVarLoad(
46 const Instruction* ptrInst, uint32_t* varId, uint32_t* varPteTypeId,
47 std::vector<std::unique_ptr<Instruction>>* newInsts) {
48 const uint32_t ldResultId = TakeNextId();
49 if (ldResultId == 0) {
50 return 0;
51 }
52
53 *varId = ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx);
54 const Instruction* varInst = get_def_use_mgr()->GetDef(*varId);
55 assert(varInst->opcode() == SpvOpVariable);
56 *varPteTypeId = GetPointeeTypeId(varInst);
57 BuildAndAppendInst(SpvOpLoad, *varPteTypeId, ldResultId,
58 {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {*varId}}},
59 newInsts);
60 return ldResultId;
61 }
62
AppendConstantOperands(const Instruction * ptrInst,std::vector<Operand> * in_opnds)63 void LocalAccessChainConvertPass::AppendConstantOperands(
64 const Instruction* ptrInst, std::vector<Operand>* in_opnds) {
65 uint32_t iidIdx = 0;
66 ptrInst->ForEachInId([&iidIdx, &in_opnds, this](const uint32_t* iid) {
67 if (iidIdx > 0) {
68 const Instruction* cInst = get_def_use_mgr()->GetDef(*iid);
69 uint32_t val = cInst->GetSingleWordInOperand(kConstantValueInIdx);
70 in_opnds->push_back(
71 {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {val}});
72 }
73 ++iidIdx;
74 });
75 }
76
ReplaceAccessChainLoad(const Instruction * address_inst,Instruction * original_load)77 bool LocalAccessChainConvertPass::ReplaceAccessChainLoad(
78 const Instruction* address_inst, Instruction* original_load) {
79 // Build and append load of variable in ptrInst
80 if (address_inst->NumInOperands() == 1) {
81 // An access chain with no indices is essentially a copy. All that is
82 // needed is to propagate the address.
83 context()->ReplaceAllUsesWith(
84 address_inst->result_id(),
85 address_inst->GetSingleWordInOperand(kAccessChainPtrIdInIdx));
86 return true;
87 }
88
89 std::vector<std::unique_ptr<Instruction>> new_inst;
90 uint32_t varId;
91 uint32_t varPteTypeId;
92 const uint32_t ldResultId =
93 BuildAndAppendVarLoad(address_inst, &varId, &varPteTypeId, &new_inst);
94 if (ldResultId == 0) {
95 return false;
96 }
97
98 new_inst[0]->UpdateDebugInfoFrom(original_load);
99 context()->get_decoration_mgr()->CloneDecorations(
100 original_load->result_id(), ldResultId, {SpvDecorationRelaxedPrecision});
101 original_load->InsertBefore(std::move(new_inst));
102 context()->get_debug_info_mgr()->AnalyzeDebugInst(
103 original_load->PreviousNode());
104
105 // Rewrite |original_load| into an extract.
106 Instruction::OperandList new_operands;
107
108 // copy the result id and the type id to the new operand list.
109 new_operands.emplace_back(original_load->GetOperand(0));
110 new_operands.emplace_back(original_load->GetOperand(1));
111
112 new_operands.emplace_back(
113 Operand({spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}));
114 AppendConstantOperands(address_inst, &new_operands);
115 original_load->SetOpcode(SpvOpCompositeExtract);
116 original_load->ReplaceOperands(new_operands);
117 context()->UpdateDefUse(original_load);
118 return true;
119 }
120
GenAccessChainStoreReplacement(const Instruction * ptrInst,uint32_t valId,std::vector<std::unique_ptr<Instruction>> * newInsts)121 bool LocalAccessChainConvertPass::GenAccessChainStoreReplacement(
122 const Instruction* ptrInst, uint32_t valId,
123 std::vector<std::unique_ptr<Instruction>>* newInsts) {
124 if (ptrInst->NumInOperands() == 1) {
125 // An access chain with no indices is essentially a copy. However, we still
126 // have to create a new store because the old ones will be deleted.
127 BuildAndAppendInst(
128 SpvOpStore, 0, 0,
129 {{spv_operand_type_t::SPV_OPERAND_TYPE_ID,
130 {ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx)}},
131 {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}}},
132 newInsts);
133 return true;
134 }
135
136 // Build and append load of variable in ptrInst
137 uint32_t varId;
138 uint32_t varPteTypeId;
139 const uint32_t ldResultId =
140 BuildAndAppendVarLoad(ptrInst, &varId, &varPteTypeId, newInsts);
141 if (ldResultId == 0) {
142 return false;
143 }
144
145 context()->get_decoration_mgr()->CloneDecorations(
146 varId, ldResultId, {SpvDecorationRelaxedPrecision});
147
148 // Build and append Insert
149 const uint32_t insResultId = TakeNextId();
150 if (insResultId == 0) {
151 return false;
152 }
153 std::vector<Operand> ins_in_opnds = {
154 {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}},
155 {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}};
156 AppendConstantOperands(ptrInst, &ins_in_opnds);
157 BuildAndAppendInst(SpvOpCompositeInsert, varPteTypeId, insResultId,
158 ins_in_opnds, newInsts);
159
160 context()->get_decoration_mgr()->CloneDecorations(
161 varId, insResultId, {SpvDecorationRelaxedPrecision});
162
163 // Build and append Store
164 BuildAndAppendInst(SpvOpStore, 0, 0,
165 {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {varId}},
166 {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {insResultId}}},
167 newInsts);
168 return true;
169 }
170
IsConstantIndexAccessChain(const Instruction * acp) const171 bool LocalAccessChainConvertPass::IsConstantIndexAccessChain(
172 const Instruction* acp) const {
173 uint32_t inIdx = 0;
174 return acp->WhileEachInId([&inIdx, this](const uint32_t* tid) {
175 if (inIdx > 0) {
176 Instruction* opInst = get_def_use_mgr()->GetDef(*tid);
177 if (opInst->opcode() != SpvOpConstant) return false;
178 }
179 ++inIdx;
180 return true;
181 });
182 }
183
HasOnlySupportedRefs(uint32_t ptrId)184 bool LocalAccessChainConvertPass::HasOnlySupportedRefs(uint32_t ptrId) {
185 if (supported_ref_ptrs_.find(ptrId) != supported_ref_ptrs_.end()) return true;
186 if (get_def_use_mgr()->WhileEachUser(ptrId, [this](Instruction* user) {
187 if (user->GetOpenCL100DebugOpcode() == OpenCLDebugInfo100DebugValue ||
188 user->GetOpenCL100DebugOpcode() == OpenCLDebugInfo100DebugDeclare) {
189 return true;
190 }
191 SpvOp op = user->opcode();
192 if (IsNonPtrAccessChain(op) || op == SpvOpCopyObject) {
193 if (!HasOnlySupportedRefs(user->result_id())) {
194 return false;
195 }
196 } else if (op != SpvOpStore && op != SpvOpLoad && op != SpvOpName &&
197 !IsNonTypeDecorate(op)) {
198 return false;
199 }
200 return true;
201 })) {
202 supported_ref_ptrs_.insert(ptrId);
203 return true;
204 }
205 return false;
206 }
207
FindTargetVars(Function * func)208 void LocalAccessChainConvertPass::FindTargetVars(Function* func) {
209 for (auto bi = func->begin(); bi != func->end(); ++bi) {
210 for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
211 switch (ii->opcode()) {
212 case SpvOpStore:
213 case SpvOpLoad: {
214 uint32_t varId;
215 Instruction* ptrInst = GetPtr(&*ii, &varId);
216 if (!IsTargetVar(varId)) break;
217 const SpvOp op = ptrInst->opcode();
218 // Rule out variables with non-supported refs eg function calls
219 if (!HasOnlySupportedRefs(varId)) {
220 seen_non_target_vars_.insert(varId);
221 seen_target_vars_.erase(varId);
222 break;
223 }
224 // Rule out variables with nested access chains
225 // TODO(): Convert nested access chains
226 if (IsNonPtrAccessChain(op) && ptrInst->GetSingleWordInOperand(
227 kAccessChainPtrIdInIdx) != varId) {
228 seen_non_target_vars_.insert(varId);
229 seen_target_vars_.erase(varId);
230 break;
231 }
232 // Rule out variables accessed with non-constant indices
233 if (!IsConstantIndexAccessChain(ptrInst)) {
234 seen_non_target_vars_.insert(varId);
235 seen_target_vars_.erase(varId);
236 break;
237 }
238 } break;
239 default:
240 break;
241 }
242 }
243 }
244 }
245
ConvertLocalAccessChains(Function * func)246 Pass::Status LocalAccessChainConvertPass::ConvertLocalAccessChains(
247 Function* func) {
248 FindTargetVars(func);
249 // Replace access chains of all targeted variables with equivalent
250 // extract and insert sequences
251 bool modified = false;
252 for (auto bi = func->begin(); bi != func->end(); ++bi) {
253 std::vector<Instruction*> dead_instructions;
254 for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
255 switch (ii->opcode()) {
256 case SpvOpLoad: {
257 uint32_t varId;
258 Instruction* ptrInst = GetPtr(&*ii, &varId);
259 if (!IsNonPtrAccessChain(ptrInst->opcode())) break;
260 if (!IsTargetVar(varId)) break;
261 if (!ReplaceAccessChainLoad(ptrInst, &*ii)) {
262 return Status::Failure;
263 }
264 modified = true;
265 } break;
266 case SpvOpStore: {
267 uint32_t varId;
268 Instruction* store = &*ii;
269 Instruction* ptrInst = GetPtr(store, &varId);
270 if (!IsNonPtrAccessChain(ptrInst->opcode())) break;
271 if (!IsTargetVar(varId)) break;
272 std::vector<std::unique_ptr<Instruction>> newInsts;
273 uint32_t valId = store->GetSingleWordInOperand(kStoreValIdInIdx);
274 if (!GenAccessChainStoreReplacement(ptrInst, valId, &newInsts)) {
275 return Status::Failure;
276 }
277 size_t num_of_instructions_to_skip = newInsts.size() - 1;
278 dead_instructions.push_back(store);
279 ++ii;
280 ii = ii.InsertBefore(std::move(newInsts));
281 for (size_t i = 0; i < num_of_instructions_to_skip; ++i) {
282 ii->UpdateDebugInfoFrom(store);
283 context()->get_debug_info_mgr()->AnalyzeDebugInst(&*ii);
284 ++ii;
285 }
286 ii->UpdateDebugInfoFrom(store);
287 context()->get_debug_info_mgr()->AnalyzeDebugInst(&*ii);
288 modified = true;
289 } break;
290 default:
291 break;
292 }
293 }
294
295 while (!dead_instructions.empty()) {
296 Instruction* inst = dead_instructions.back();
297 dead_instructions.pop_back();
298 DCEInst(inst, [&dead_instructions](Instruction* other_inst) {
299 auto i = std::find(dead_instructions.begin(), dead_instructions.end(),
300 other_inst);
301 if (i != dead_instructions.end()) {
302 dead_instructions.erase(i);
303 }
304 });
305 }
306 }
307 return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
308 }
309
Initialize()310 void LocalAccessChainConvertPass::Initialize() {
311 // Initialize Target Variable Caches
312 seen_target_vars_.clear();
313 seen_non_target_vars_.clear();
314
315 // Initialize collections
316 supported_ref_ptrs_.clear();
317
318 // Initialize extension allowlist
319 InitExtensions();
320 }
321
AllExtensionsSupported() const322 bool LocalAccessChainConvertPass::AllExtensionsSupported() const {
323 // This capability can now exist without the extension, so we have to check
324 // for the capability. This pass is only looking at function scope symbols,
325 // so we do not care if there are variable pointers on storage buffers.
326 if (context()->get_feature_mgr()->HasCapability(
327 SpvCapabilityVariablePointers))
328 return false;
329 // If any extension not in allowlist, return false
330 for (auto& ei : get_module()->extensions()) {
331 const char* extName =
332 reinterpret_cast<const char*>(&ei.GetInOperand(0).words[0]);
333 if (extensions_allowlist_.find(extName) == extensions_allowlist_.end())
334 return false;
335 }
336 return true;
337 }
338
ProcessImpl()339 Pass::Status LocalAccessChainConvertPass::ProcessImpl() {
340 // If non-32-bit integer type in module, terminate processing
341 // TODO(): Handle non-32-bit integer constants in access chains
342 for (const Instruction& inst : get_module()->types_values())
343 if (inst.opcode() == SpvOpTypeInt &&
344 inst.GetSingleWordInOperand(kTypeIntWidthInIdx) != 32)
345 return Status::SuccessWithoutChange;
346 // Do not process if module contains OpGroupDecorate. Additional
347 // support required in KillNamesAndDecorates().
348 // TODO(greg-lunarg): Add support for OpGroupDecorate
349 for (auto& ai : get_module()->annotations())
350 if (ai.opcode() == SpvOpGroupDecorate) return Status::SuccessWithoutChange;
351 // Do not process if any disallowed extensions are enabled
352 if (!AllExtensionsSupported()) return Status::SuccessWithoutChange;
353
354 // Process all functions in the module.
355 Status status = Status::SuccessWithoutChange;
356 for (Function& func : *get_module()) {
357 status = CombineStatus(status, ConvertLocalAccessChains(&func));
358 if (status == Status::Failure) {
359 break;
360 }
361 }
362 return status;
363 }
364
LocalAccessChainConvertPass()365 LocalAccessChainConvertPass::LocalAccessChainConvertPass() {}
366
Process()367 Pass::Status LocalAccessChainConvertPass::Process() {
368 Initialize();
369 return ProcessImpl();
370 }
371
InitExtensions()372 void LocalAccessChainConvertPass::InitExtensions() {
373 extensions_allowlist_.clear();
374 extensions_allowlist_.insert({
375 "SPV_AMD_shader_explicit_vertex_parameter",
376 "SPV_AMD_shader_trinary_minmax",
377 "SPV_AMD_gcn_shader",
378 "SPV_KHR_shader_ballot",
379 "SPV_AMD_shader_ballot",
380 "SPV_AMD_gpu_shader_half_float",
381 "SPV_KHR_shader_draw_parameters",
382 "SPV_KHR_subgroup_vote",
383 "SPV_KHR_8bit_storage",
384 "SPV_KHR_16bit_storage",
385 "SPV_KHR_device_group",
386 "SPV_KHR_multiview",
387 "SPV_NVX_multiview_per_view_attributes",
388 "SPV_NV_viewport_array2",
389 "SPV_NV_stereo_view_rendering",
390 "SPV_NV_sample_mask_override_coverage",
391 "SPV_NV_geometry_shader_passthrough",
392 "SPV_AMD_texture_gather_bias_lod",
393 "SPV_KHR_storage_buffer_storage_class",
394 // SPV_KHR_variable_pointers
395 // Currently do not support extended pointer expressions
396 "SPV_AMD_gpu_shader_int16",
397 "SPV_KHR_post_depth_coverage",
398 "SPV_KHR_shader_atomic_counter_ops",
399 "SPV_EXT_shader_stencil_export",
400 "SPV_EXT_shader_viewport_index_layer",
401 "SPV_AMD_shader_image_load_store_lod",
402 "SPV_AMD_shader_fragment_mask",
403 "SPV_EXT_fragment_fully_covered",
404 "SPV_AMD_gpu_shader_half_float_fetch",
405 "SPV_GOOGLE_decorate_string",
406 "SPV_GOOGLE_hlsl_functionality1",
407 "SPV_GOOGLE_user_type",
408 "SPV_NV_shader_subgroup_partitioned",
409 "SPV_EXT_demote_to_helper_invocation",
410 "SPV_EXT_descriptor_indexing",
411 "SPV_NV_fragment_shader_barycentric",
412 "SPV_NV_compute_shader_derivatives",
413 "SPV_NV_shader_image_footprint",
414 "SPV_NV_shading_rate",
415 "SPV_NV_mesh_shader",
416 "SPV_NV_ray_tracing",
417 "SPV_KHR_ray_tracing",
418 "SPV_KHR_ray_query",
419 "SPV_EXT_fragment_invocation_density",
420 "SPV_KHR_terminate_invocation",
421 });
422 }
423
424 } // namespace opt
425 } // namespace spvtools
426