1 // Copyright (c) 2018 The Khronos Group Inc.
2 // Copyright (c) 2018 Valve Corporation
3 // Copyright (c) 2018 LunarG Inc.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16
17 #include "inst_bindless_check_pass.h"
18
19 #include "source/spirv_constant.h"
20
21 namespace spvtools {
22 namespace opt {
namespace {
// In-operand positions (result type and result id excluded) used when
// decoding instructions with Instruction::GetSingleWordInOperand().

// Image instructions handled by GetImageId(): the image operand.
constexpr uint32_t kSpvImageSampleImageIdInIdx = 0;
// OpSampledImage: Image, Sampler.
constexpr uint32_t kSpvSampledImageImageIdInIdx = 0;
constexpr uint32_t kSpvSampledImageSamplerIdInIdx = 1;
// OpImage: Sampled Image.
constexpr uint32_t kSpvImageSampledImageIdInIdx = 0;
// OpCopyObject: Operand.
constexpr uint32_t kSpvCopyObjectOperandIdInIdx = 0;
// OpLoad / OpStore: Pointer.
constexpr uint32_t kSpvLoadPtrIdInIdx = 0;
// OpAccessChain: Base, followed by the first index.
constexpr uint32_t kSpvAccessChainBaseIdInIdx = 0;
constexpr uint32_t kSpvAccessChainIndex0IdInIdx = 1;
// OpTypeArray / OpTypeRuntimeArray: Element Type.
constexpr uint32_t kSpvTypeArrayTypeIdInIdx = 0;
// OpVariable: Storage Class.
constexpr uint32_t kSpvVariableStorageClassInIdx = 0;
// OpTypePointer: Storage Class, Type.
constexpr uint32_t kSpvTypePtrTypeIdInIdx = 1;
// OpTypeImage: Sampled Type, Dim, Depth, Arrayed, MS, ...
constexpr uint32_t kSpvTypeImageDim = 1;
constexpr uint32_t kSpvTypeImageDepth = 2;
constexpr uint32_t kSpvTypeImageArrayed = 3;
constexpr uint32_t kSpvTypeImageMS = 4;
}  // namespace
41
42 // This is a stub function for use with Import linkage
43 // clang-format off
44 // GLSL:
45 //bool inst_bindless_check_desc(const uint shader_id, const uint inst_num, const uvec4 stage_info, const uint desc_set,
46 // const uint binding, const uint desc_index, const uint byte_offset) {
47 //}
48 // clang-format on
GenDescCheckFunctionId()49 uint32_t InstBindlessCheckPass::GenDescCheckFunctionId() {
50 enum {
51 kShaderId = 0,
52 kInstructionIndex = 1,
53 kStageInfo = 2,
54 kDescSet = 3,
55 kDescBinding = 4,
56 kDescIndex = 5,
57 kByteOffset = 6,
58 kNumArgs
59 };
60 if (check_desc_func_id_ != 0) {
61 return check_desc_func_id_;
62 }
63
64 analysis::TypeManager* type_mgr = context()->get_type_mgr();
65 const analysis::Integer* uint_type = GetInteger(32, false);
66 const analysis::Vector v4uint(uint_type, 4);
67 const analysis::Type* v4uint_type = type_mgr->GetRegisteredType(&v4uint);
68 std::vector<const analysis::Type*> param_types(kNumArgs, uint_type);
69 param_types[2] = v4uint_type;
70
71 const uint32_t func_id = TakeNextId();
72 std::unique_ptr<Function> func =
73 StartFunction(func_id, type_mgr->GetBoolType(), param_types);
74
75 func->SetFunctionEnd(EndFunction());
76
77 static const std::string func_name{"inst_bindless_check_desc"};
78 context()->AddFunctionDeclaration(std::move(func));
79 context()->AddDebug2Inst(NewName(func_id, func_name));
80 std::vector<Operand> operands{
81 {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {func_id}},
82 {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
83 {uint32_t(spv::Decoration::LinkageAttributes)}},
84 {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_STRING,
85 utils::MakeVector(func_name.c_str())},
86 {spv_operand_type_t::SPV_OPERAND_TYPE_LINKAGE_TYPE,
87 {uint32_t(spv::LinkageType::Import)}},
88 };
89 get_decoration_mgr()->AddDecoration(spv::Op::OpDecorate, operands);
90
91 check_desc_func_id_ = func_id;
92 // Make sure function doesn't get processed by
93 // InstrumentPass::InstProcessCallTreeFromRoots()
94 param2output_func_id_[3] = func_id;
95 return check_desc_func_id_;
96 }
97
98 // clang-format off
99 // GLSL:
100 // result = inst_bindless_check_desc(shader_id, inst_idx, stage_info, desc_set, binding, desc_idx, offset);
101 //
102 // clang-format on
GenDescCheckCall(uint32_t inst_idx,uint32_t stage_idx,uint32_t var_id,uint32_t desc_idx_id,uint32_t offset_id,InstructionBuilder * builder)103 uint32_t InstBindlessCheckPass::GenDescCheckCall(
104 uint32_t inst_idx, uint32_t stage_idx, uint32_t var_id,
105 uint32_t desc_idx_id, uint32_t offset_id, InstructionBuilder* builder) {
106 const uint32_t func_id = GenDescCheckFunctionId();
107 const std::vector<uint32_t> args = {
108 builder->GetUintConstantId(shader_id_),
109 builder->GetUintConstantId(inst_idx),
110 GenStageInfo(stage_idx, builder),
111 builder->GetUintConstantId(var2desc_set_[var_id]),
112 builder->GetUintConstantId(var2binding_[var_id]),
113 GenUintCastCode(desc_idx_id, builder),
114 offset_id};
115 return GenReadFunctionCall(GetBoolId(), func_id, args, builder);
116 }
117
CloneOriginalImage(uint32_t old_image_id,InstructionBuilder * builder)118 uint32_t InstBindlessCheckPass::CloneOriginalImage(
119 uint32_t old_image_id, InstructionBuilder* builder) {
120 Instruction* new_image_inst;
121 Instruction* old_image_inst = get_def_use_mgr()->GetDef(old_image_id);
122 if (old_image_inst->opcode() == spv::Op::OpLoad) {
123 new_image_inst = builder->AddLoad(
124 old_image_inst->type_id(),
125 old_image_inst->GetSingleWordInOperand(kSpvLoadPtrIdInIdx));
126 } else if (old_image_inst->opcode() == spv::Op::OpSampledImage) {
127 uint32_t clone_id = CloneOriginalImage(
128 old_image_inst->GetSingleWordInOperand(kSpvSampledImageImageIdInIdx),
129 builder);
130 new_image_inst = builder->AddBinaryOp(
131 old_image_inst->type_id(), spv::Op::OpSampledImage, clone_id,
132 old_image_inst->GetSingleWordInOperand(kSpvSampledImageSamplerIdInIdx));
133 } else if (old_image_inst->opcode() == spv::Op::OpImage) {
134 uint32_t clone_id = CloneOriginalImage(
135 old_image_inst->GetSingleWordInOperand(kSpvImageSampledImageIdInIdx),
136 builder);
137 new_image_inst = builder->AddUnaryOp(old_image_inst->type_id(),
138 spv::Op::OpImage, clone_id);
139 } else {
140 assert(old_image_inst->opcode() == spv::Op::OpCopyObject &&
141 "expecting OpCopyObject");
142 uint32_t clone_id = CloneOriginalImage(
143 old_image_inst->GetSingleWordInOperand(kSpvCopyObjectOperandIdInIdx),
144 builder);
145 // Since we are cloning, no need to create new copy
146 new_image_inst = get_def_use_mgr()->GetDef(clone_id);
147 }
148 uid2offset_[new_image_inst->unique_id()] =
149 uid2offset_[old_image_inst->unique_id()];
150 uint32_t new_image_id = new_image_inst->result_id();
151 get_decoration_mgr()->CloneDecorations(old_image_id, new_image_id);
152 return new_image_id;
153 }
154
CloneOriginalReference(RefAnalysis * ref,InstructionBuilder * builder)155 uint32_t InstBindlessCheckPass::CloneOriginalReference(
156 RefAnalysis* ref, InstructionBuilder* builder) {
157 // If original is image based, start by cloning descriptor load
158 uint32_t new_image_id = 0;
159 if (ref->desc_load_id != 0) {
160 uint32_t old_image_id =
161 ref->ref_inst->GetSingleWordInOperand(kSpvImageSampleImageIdInIdx);
162 new_image_id = CloneOriginalImage(old_image_id, builder);
163 }
164 // Clone original reference
165 std::unique_ptr<Instruction> new_ref_inst(ref->ref_inst->Clone(context()));
166 uint32_t ref_result_id = ref->ref_inst->result_id();
167 uint32_t new_ref_id = 0;
168 if (ref_result_id != 0) {
169 new_ref_id = TakeNextId();
170 new_ref_inst->SetResultId(new_ref_id);
171 }
172 // Update new ref with new image if created
173 if (new_image_id != 0)
174 new_ref_inst->SetInOperand(kSpvImageSampleImageIdInIdx, {new_image_id});
175 // Register new reference and add to new block
176 Instruction* added_inst = builder->AddInstruction(std::move(new_ref_inst));
177 uid2offset_[added_inst->unique_id()] =
178 uid2offset_[ref->ref_inst->unique_id()];
179 if (new_ref_id != 0)
180 get_decoration_mgr()->CloneDecorations(ref_result_id, new_ref_id);
181 return new_ref_id;
182 }
183
GetImageId(Instruction * inst)184 uint32_t InstBindlessCheckPass::GetImageId(Instruction* inst) {
185 switch (inst->opcode()) {
186 case spv::Op::OpImageSampleImplicitLod:
187 case spv::Op::OpImageSampleExplicitLod:
188 case spv::Op::OpImageSampleDrefImplicitLod:
189 case spv::Op::OpImageSampleDrefExplicitLod:
190 case spv::Op::OpImageSampleProjImplicitLod:
191 case spv::Op::OpImageSampleProjExplicitLod:
192 case spv::Op::OpImageSampleProjDrefImplicitLod:
193 case spv::Op::OpImageSampleProjDrefExplicitLod:
194 case spv::Op::OpImageGather:
195 case spv::Op::OpImageDrefGather:
196 case spv::Op::OpImageQueryLod:
197 case spv::Op::OpImageSparseSampleImplicitLod:
198 case spv::Op::OpImageSparseSampleExplicitLod:
199 case spv::Op::OpImageSparseSampleDrefImplicitLod:
200 case spv::Op::OpImageSparseSampleDrefExplicitLod:
201 case spv::Op::OpImageSparseSampleProjImplicitLod:
202 case spv::Op::OpImageSparseSampleProjExplicitLod:
203 case spv::Op::OpImageSparseSampleProjDrefImplicitLod:
204 case spv::Op::OpImageSparseSampleProjDrefExplicitLod:
205 case spv::Op::OpImageSparseGather:
206 case spv::Op::OpImageSparseDrefGather:
207 case spv::Op::OpImageFetch:
208 case spv::Op::OpImageRead:
209 case spv::Op::OpImageQueryFormat:
210 case spv::Op::OpImageQueryOrder:
211 case spv::Op::OpImageQuerySizeLod:
212 case spv::Op::OpImageQuerySize:
213 case spv::Op::OpImageQueryLevels:
214 case spv::Op::OpImageQuerySamples:
215 case spv::Op::OpImageSparseFetch:
216 case spv::Op::OpImageSparseRead:
217 case spv::Op::OpImageWrite:
218 return inst->GetSingleWordInOperand(kSpvImageSampleImageIdInIdx);
219 default:
220 break;
221 }
222 return 0;
223 }
224
GetPointeeTypeInst(Instruction * ptr_inst)225 Instruction* InstBindlessCheckPass::GetPointeeTypeInst(Instruction* ptr_inst) {
226 uint32_t pte_ty_id = GetPointeeTypeId(ptr_inst);
227 return get_def_use_mgr()->GetDef(pte_ty_id);
228 }
229
AnalyzeDescriptorReference(Instruction * ref_inst,RefAnalysis * ref)230 bool InstBindlessCheckPass::AnalyzeDescriptorReference(Instruction* ref_inst,
231 RefAnalysis* ref) {
232 ref->ref_inst = ref_inst;
233 if (ref_inst->opcode() == spv::Op::OpLoad ||
234 ref_inst->opcode() == spv::Op::OpStore) {
235 ref->desc_load_id = 0;
236 ref->ptr_id = ref_inst->GetSingleWordInOperand(kSpvLoadPtrIdInIdx);
237 Instruction* ptr_inst = get_def_use_mgr()->GetDef(ref->ptr_id);
238 if (ptr_inst->opcode() != spv::Op::OpAccessChain) return false;
239 ref->var_id = ptr_inst->GetSingleWordInOperand(kSpvAccessChainBaseIdInIdx);
240 Instruction* var_inst = get_def_use_mgr()->GetDef(ref->var_id);
241 if (var_inst->opcode() != spv::Op::OpVariable) return false;
242 spv::StorageClass storage_class = spv::StorageClass(
243 var_inst->GetSingleWordInOperand(kSpvVariableStorageClassInIdx));
244 switch (storage_class) {
245 case spv::StorageClass::Uniform:
246 case spv::StorageClass::StorageBuffer:
247 break;
248 default:
249 return false;
250 break;
251 }
252 // Check for deprecated storage block form
253 if (storage_class == spv::StorageClass::Uniform) {
254 uint32_t var_ty_id = var_inst->type_id();
255 Instruction* var_ty_inst = get_def_use_mgr()->GetDef(var_ty_id);
256 uint32_t ptr_ty_id =
257 var_ty_inst->GetSingleWordInOperand(kSpvTypePtrTypeIdInIdx);
258 Instruction* ptr_ty_inst = get_def_use_mgr()->GetDef(ptr_ty_id);
259 spv::Op ptr_ty_op = ptr_ty_inst->opcode();
260 uint32_t block_ty_id =
261 (ptr_ty_op == spv::Op::OpTypeArray ||
262 ptr_ty_op == spv::Op::OpTypeRuntimeArray)
263 ? ptr_ty_inst->GetSingleWordInOperand(kSpvTypeArrayTypeIdInIdx)
264 : ptr_ty_id;
265 assert(get_def_use_mgr()->GetDef(block_ty_id)->opcode() ==
266 spv::Op::OpTypeStruct &&
267 "unexpected block type");
268 bool block_found = get_decoration_mgr()->FindDecoration(
269 block_ty_id, uint32_t(spv::Decoration::Block),
270 [](const Instruction&) { return true; });
271 if (!block_found) {
272 // If block decoration not found, verify deprecated form of SSBO
273 bool buffer_block_found = get_decoration_mgr()->FindDecoration(
274 block_ty_id, uint32_t(spv::Decoration::BufferBlock),
275 [](const Instruction&) { return true; });
276 USE_ASSERT(buffer_block_found && "block decoration not found");
277 storage_class = spv::StorageClass::StorageBuffer;
278 }
279 }
280 ref->strg_class = uint32_t(storage_class);
281 Instruction* desc_type_inst = GetPointeeTypeInst(var_inst);
282 switch (desc_type_inst->opcode()) {
283 case spv::Op::OpTypeArray:
284 case spv::Op::OpTypeRuntimeArray:
285 // A load through a descriptor array will have at least 3 operands. We
286 // do not want to instrument loads of descriptors here which are part of
287 // an image-based reference.
288 if (ptr_inst->NumInOperands() < 3) return false;
289 ref->desc_idx_id =
290 ptr_inst->GetSingleWordInOperand(kSpvAccessChainIndex0IdInIdx);
291 break;
292 default:
293 break;
294 }
295 auto decos =
296 context()->get_decoration_mgr()->GetDecorationsFor(ref->var_id, false);
297 for (const auto& deco : decos) {
298 spv::Decoration d = spv::Decoration(deco->GetSingleWordInOperand(1u));
299 if (d == spv::Decoration::DescriptorSet) {
300 ref->set = deco->GetSingleWordInOperand(2u);
301 } else if (d == spv::Decoration::Binding) {
302 ref->binding = deco->GetSingleWordInOperand(2u);
303 }
304 }
305 return true;
306 }
307 // Reference is not load or store. If not an image-based reference, return.
308 ref->image_id = GetImageId(ref_inst);
309 if (ref->image_id == 0) return false;
310 // Search for descriptor load
311 uint32_t desc_load_id = ref->image_id;
312 Instruction* desc_load_inst;
313 for (;;) {
314 desc_load_inst = get_def_use_mgr()->GetDef(desc_load_id);
315 if (desc_load_inst->opcode() == spv::Op::OpSampledImage)
316 desc_load_id =
317 desc_load_inst->GetSingleWordInOperand(kSpvSampledImageImageIdInIdx);
318 else if (desc_load_inst->opcode() == spv::Op::OpImage)
319 desc_load_id =
320 desc_load_inst->GetSingleWordInOperand(kSpvImageSampledImageIdInIdx);
321 else if (desc_load_inst->opcode() == spv::Op::OpCopyObject)
322 desc_load_id =
323 desc_load_inst->GetSingleWordInOperand(kSpvCopyObjectOperandIdInIdx);
324 else
325 break;
326 }
327 if (desc_load_inst->opcode() != spv::Op::OpLoad) {
328 // TODO(greg-lunarg): Handle additional possibilities?
329 return false;
330 }
331 ref->desc_load_id = desc_load_id;
332 ref->ptr_id = desc_load_inst->GetSingleWordInOperand(kSpvLoadPtrIdInIdx);
333 Instruction* ptr_inst = get_def_use_mgr()->GetDef(ref->ptr_id);
334 if (ptr_inst->opcode() == spv::Op::OpVariable) {
335 ref->desc_idx_id = 0;
336 ref->var_id = ref->ptr_id;
337 } else if (ptr_inst->opcode() == spv::Op::OpAccessChain) {
338 if (ptr_inst->NumInOperands() != 2) {
339 assert(false && "unexpected bindless index number");
340 return false;
341 }
342 ref->desc_idx_id =
343 ptr_inst->GetSingleWordInOperand(kSpvAccessChainIndex0IdInIdx);
344 ref->var_id = ptr_inst->GetSingleWordInOperand(kSpvAccessChainBaseIdInIdx);
345 Instruction* var_inst = get_def_use_mgr()->GetDef(ref->var_id);
346 if (var_inst->opcode() != spv::Op::OpVariable) {
347 assert(false && "unexpected bindless base");
348 return false;
349 }
350 } else {
351 // TODO(greg-lunarg): Handle additional possibilities?
352 return false;
353 }
354 auto decos =
355 context()->get_decoration_mgr()->GetDecorationsFor(ref->var_id, false);
356 for (const auto& deco : decos) {
357 spv::Decoration d = spv::Decoration(deco->GetSingleWordInOperand(1u));
358 if (d == spv::Decoration::DescriptorSet) {
359 ref->set = deco->GetSingleWordInOperand(2u);
360 } else if (d == spv::Decoration::Binding) {
361 ref->binding = deco->GetSingleWordInOperand(2u);
362 }
363 }
364 return true;
365 }
366
FindStride(uint32_t ty_id,uint32_t stride_deco)367 uint32_t InstBindlessCheckPass::FindStride(uint32_t ty_id,
368 uint32_t stride_deco) {
369 uint32_t stride = 0xdeadbeef;
370 bool found = get_decoration_mgr()->FindDecoration(
371 ty_id, stride_deco, [&stride](const Instruction& deco_inst) {
372 stride = deco_inst.GetSingleWordInOperand(2u);
373 return true;
374 });
375 USE_ASSERT(found && "stride not found");
376 return stride;
377 }
378
ByteSize(uint32_t ty_id,uint32_t matrix_stride,bool col_major,bool in_matrix)379 uint32_t InstBindlessCheckPass::ByteSize(uint32_t ty_id, uint32_t matrix_stride,
380 bool col_major, bool in_matrix) {
381 analysis::TypeManager* type_mgr = context()->get_type_mgr();
382 const analysis::Type* sz_ty = type_mgr->GetType(ty_id);
383 if (sz_ty->kind() == analysis::Type::kPointer) {
384 // Assuming PhysicalStorageBuffer pointer
385 return 8;
386 }
387 if (sz_ty->kind() == analysis::Type::kMatrix) {
388 assert(matrix_stride != 0 && "missing matrix stride");
389 const analysis::Matrix* m_ty = sz_ty->AsMatrix();
390 if (col_major) {
391 return m_ty->element_count() * matrix_stride;
392 } else {
393 const analysis::Vector* v_ty = m_ty->element_type()->AsVector();
394 return v_ty->element_count() * matrix_stride;
395 }
396 }
397 uint32_t size = 1;
398 if (sz_ty->kind() == analysis::Type::kVector) {
399 const analysis::Vector* v_ty = sz_ty->AsVector();
400 size = v_ty->element_count();
401 const analysis::Type* comp_ty = v_ty->element_type();
402 // if vector in row major matrix, the vector is strided so return the
403 // number of bytes spanned by the vector
404 if (in_matrix && !col_major && matrix_stride > 0) {
405 uint32_t comp_ty_id = type_mgr->GetId(comp_ty);
406 return (size - 1) * matrix_stride + ByteSize(comp_ty_id, 0, false, false);
407 }
408 sz_ty = comp_ty;
409 }
410 switch (sz_ty->kind()) {
411 case analysis::Type::kFloat: {
412 const analysis::Float* f_ty = sz_ty->AsFloat();
413 size *= f_ty->width();
414 } break;
415 case analysis::Type::kInteger: {
416 const analysis::Integer* i_ty = sz_ty->AsInteger();
417 size *= i_ty->width();
418 } break;
419 default: { assert(false && "unexpected type"); } break;
420 }
421 size /= 8;
422 return size;
423 }
424
GenLastByteIdx(RefAnalysis * ref,InstructionBuilder * builder)425 uint32_t InstBindlessCheckPass::GenLastByteIdx(RefAnalysis* ref,
426 InstructionBuilder* builder) {
427 // Find outermost buffer type and its access chain index
428 Instruction* var_inst = get_def_use_mgr()->GetDef(ref->var_id);
429 Instruction* desc_ty_inst = GetPointeeTypeInst(var_inst);
430 uint32_t buff_ty_id;
431 uint32_t ac_in_idx = 1;
432 switch (desc_ty_inst->opcode()) {
433 case spv::Op::OpTypeArray:
434 case spv::Op::OpTypeRuntimeArray:
435 buff_ty_id = desc_ty_inst->GetSingleWordInOperand(0);
436 ++ac_in_idx;
437 break;
438 default:
439 assert(desc_ty_inst->opcode() == spv::Op::OpTypeStruct &&
440 "unexpected descriptor type");
441 buff_ty_id = desc_ty_inst->result_id();
442 break;
443 }
444 // Process remaining access chain indices
445 Instruction* ac_inst = get_def_use_mgr()->GetDef(ref->ptr_id);
446 uint32_t curr_ty_id = buff_ty_id;
447 uint32_t sum_id = 0u;
448 uint32_t matrix_stride = 0u;
449 bool col_major = false;
450 uint32_t matrix_stride_id = 0u;
451 bool in_matrix = false;
452 while (ac_in_idx < ac_inst->NumInOperands()) {
453 uint32_t curr_idx_id = ac_inst->GetSingleWordInOperand(ac_in_idx);
454 Instruction* curr_ty_inst = get_def_use_mgr()->GetDef(curr_ty_id);
455 uint32_t curr_offset_id = 0;
456 switch (curr_ty_inst->opcode()) {
457 case spv::Op::OpTypeArray:
458 case spv::Op::OpTypeRuntimeArray: {
459 // Get array stride and multiply by current index
460 uint32_t arr_stride =
461 FindStride(curr_ty_id, uint32_t(spv::Decoration::ArrayStride));
462 uint32_t arr_stride_id = builder->GetUintConstantId(arr_stride);
463 uint32_t curr_idx_32b_id = Gen32BitCvtCode(curr_idx_id, builder);
464 Instruction* curr_offset_inst = builder->AddBinaryOp(
465 GetUintId(), spv::Op::OpIMul, arr_stride_id, curr_idx_32b_id);
466 curr_offset_id = curr_offset_inst->result_id();
467 // Get element type for next step
468 curr_ty_id = curr_ty_inst->GetSingleWordInOperand(0);
469 } break;
470 case spv::Op::OpTypeMatrix: {
471 assert(matrix_stride != 0 && "missing matrix stride");
472 matrix_stride_id = builder->GetUintConstantId(matrix_stride);
473 uint32_t vec_ty_id = curr_ty_inst->GetSingleWordInOperand(0);
474 // If column major, multiply column index by matrix stride, otherwise
475 // by vector component size and save matrix stride for vector (row)
476 // index
477 uint32_t col_stride_id;
478 if (col_major) {
479 col_stride_id = matrix_stride_id;
480 } else {
481 Instruction* vec_ty_inst = get_def_use_mgr()->GetDef(vec_ty_id);
482 uint32_t comp_ty_id = vec_ty_inst->GetSingleWordInOperand(0u);
483 uint32_t col_stride = ByteSize(comp_ty_id, 0u, false, false);
484 col_stride_id = builder->GetUintConstantId(col_stride);
485 }
486 uint32_t curr_idx_32b_id = Gen32BitCvtCode(curr_idx_id, builder);
487 Instruction* curr_offset_inst = builder->AddBinaryOp(
488 GetUintId(), spv::Op::OpIMul, col_stride_id, curr_idx_32b_id);
489 curr_offset_id = curr_offset_inst->result_id();
490 // Get element type for next step
491 curr_ty_id = vec_ty_id;
492 in_matrix = true;
493 } break;
494 case spv::Op::OpTypeVector: {
495 // If inside a row major matrix type, multiply index by matrix stride,
496 // else multiply by component size
497 uint32_t comp_ty_id = curr_ty_inst->GetSingleWordInOperand(0u);
498 uint32_t curr_idx_32b_id = Gen32BitCvtCode(curr_idx_id, builder);
499 if (in_matrix && !col_major) {
500 Instruction* curr_offset_inst = builder->AddBinaryOp(
501 GetUintId(), spv::Op::OpIMul, matrix_stride_id, curr_idx_32b_id);
502 curr_offset_id = curr_offset_inst->result_id();
503 } else {
504 uint32_t comp_ty_sz = ByteSize(comp_ty_id, 0u, false, false);
505 uint32_t comp_ty_sz_id = builder->GetUintConstantId(comp_ty_sz);
506 Instruction* curr_offset_inst = builder->AddBinaryOp(
507 GetUintId(), spv::Op::OpIMul, comp_ty_sz_id, curr_idx_32b_id);
508 curr_offset_id = curr_offset_inst->result_id();
509 }
510 // Get element type for next step
511 curr_ty_id = comp_ty_id;
512 } break;
513 case spv::Op::OpTypeStruct: {
514 // Get buffer byte offset for the referenced member
515 Instruction* curr_idx_inst = get_def_use_mgr()->GetDef(curr_idx_id);
516 assert(curr_idx_inst->opcode() == spv::Op::OpConstant &&
517 "unexpected struct index");
518 uint32_t member_idx = curr_idx_inst->GetSingleWordInOperand(0);
519 uint32_t member_offset = 0xdeadbeef;
520 bool found = get_decoration_mgr()->FindDecoration(
521 curr_ty_id, uint32_t(spv::Decoration::Offset),
522 [&member_idx, &member_offset](const Instruction& deco_inst) {
523 if (deco_inst.GetSingleWordInOperand(1u) != member_idx)
524 return false;
525 member_offset = deco_inst.GetSingleWordInOperand(3u);
526 return true;
527 });
528 USE_ASSERT(found && "member offset not found");
529 curr_offset_id = builder->GetUintConstantId(member_offset);
530 // Look for matrix stride for this member if there is one. The matrix
531 // stride is not on the matrix type, but in a OpMemberDecorate on the
532 // enclosing struct type at the member index. If none found, reset
533 // stride to 0.
534 found = get_decoration_mgr()->FindDecoration(
535 curr_ty_id, uint32_t(spv::Decoration::MatrixStride),
536 [&member_idx, &matrix_stride](const Instruction& deco_inst) {
537 if (deco_inst.GetSingleWordInOperand(1u) != member_idx)
538 return false;
539 matrix_stride = deco_inst.GetSingleWordInOperand(3u);
540 return true;
541 });
542 if (!found) matrix_stride = 0;
543 // Look for column major decoration
544 found = get_decoration_mgr()->FindDecoration(
545 curr_ty_id, uint32_t(spv::Decoration::ColMajor),
546 [&member_idx, &col_major](const Instruction& deco_inst) {
547 if (deco_inst.GetSingleWordInOperand(1u) != member_idx)
548 return false;
549 col_major = true;
550 return true;
551 });
552 if (!found) col_major = false;
553 // Get element type for next step
554 curr_ty_id = curr_ty_inst->GetSingleWordInOperand(member_idx);
555 } break;
556 default: { assert(false && "unexpected non-composite type"); } break;
557 }
558 if (sum_id == 0)
559 sum_id = curr_offset_id;
560 else {
561 Instruction* sum_inst =
562 builder->AddIAdd(GetUintId(), sum_id, curr_offset_id);
563 sum_id = sum_inst->result_id();
564 }
565 ++ac_in_idx;
566 }
567 // Add in offset of last byte of referenced object
568 uint32_t bsize = ByteSize(curr_ty_id, matrix_stride, col_major, in_matrix);
569 uint32_t last = bsize - 1;
570 uint32_t last_id = builder->GetUintConstantId(last);
571 Instruction* sum_inst = builder->AddIAdd(GetUintId(), sum_id, last_id);
572 return sum_inst->result_id();
573 }
574
GenCheckCode(uint32_t check_id,RefAnalysis * ref,std::vector<std::unique_ptr<BasicBlock>> * new_blocks)575 void InstBindlessCheckPass::GenCheckCode(
576 uint32_t check_id, RefAnalysis* ref,
577 std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
578 BasicBlock* back_blk_ptr = &*new_blocks->back();
579 InstructionBuilder builder(
580 context(), back_blk_ptr,
581 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
582 // Gen conditional branch on check_id. Valid branch generates original
583 // reference. Invalid generates debug output and zero result (if needed).
584 uint32_t merge_blk_id = TakeNextId();
585 uint32_t valid_blk_id = TakeNextId();
586 uint32_t invalid_blk_id = TakeNextId();
587 std::unique_ptr<Instruction> merge_label(NewLabel(merge_blk_id));
588 std::unique_ptr<Instruction> valid_label(NewLabel(valid_blk_id));
589 std::unique_ptr<Instruction> invalid_label(NewLabel(invalid_blk_id));
590 (void)builder.AddConditionalBranch(
591 check_id, valid_blk_id, invalid_blk_id, merge_blk_id,
592 uint32_t(spv::SelectionControlMask::MaskNone));
593 // Gen valid bounds branch
594 std::unique_ptr<BasicBlock> new_blk_ptr(
595 new BasicBlock(std::move(valid_label)));
596 builder.SetInsertPoint(&*new_blk_ptr);
597 uint32_t new_ref_id = CloneOriginalReference(ref, &builder);
598 uint32_t null_id = 0;
599 uint32_t ref_type_id = ref->ref_inst->type_id();
600 (void)builder.AddBranch(merge_blk_id);
601 new_blocks->push_back(std::move(new_blk_ptr));
602 // Gen invalid block
603 new_blk_ptr.reset(new BasicBlock(std::move(invalid_label)));
604 builder.SetInsertPoint(&*new_blk_ptr);
605
606 // Generate a ConstantNull, converting to uint64 if the type cannot be a null.
607 if (new_ref_id != 0) {
608 analysis::TypeManager* type_mgr = context()->get_type_mgr();
609 analysis::Type* ref_type = type_mgr->GetType(ref_type_id);
610 if (ref_type->AsPointer() != nullptr) {
611 context()->AddCapability(spv::Capability::Int64);
612 uint32_t null_u64_id = GetNullId(GetUint64Id());
613 Instruction* null_ptr_inst = builder.AddUnaryOp(
614 ref_type_id, spv::Op::OpConvertUToPtr, null_u64_id);
615 null_id = null_ptr_inst->result_id();
616 } else {
617 null_id = GetNullId(ref_type_id);
618 }
619 }
620 // Remember last invalid block id
621 uint32_t last_invalid_blk_id = new_blk_ptr->GetLabelInst()->result_id();
622 // Gen zero for invalid reference
623 (void)builder.AddBranch(merge_blk_id);
624 new_blocks->push_back(std::move(new_blk_ptr));
625 // Gen merge block
626 new_blk_ptr.reset(new BasicBlock(std::move(merge_label)));
627 builder.SetInsertPoint(&*new_blk_ptr);
628 // Gen phi of new reference and zero, if necessary, and replace the
629 // result id of the original reference with that of the Phi. Kill original
630 // reference.
631 if (new_ref_id != 0) {
632 Instruction* phi_inst = builder.AddPhi(
633 ref_type_id, {new_ref_id, valid_blk_id, null_id, last_invalid_blk_id});
634 context()->ReplaceAllUsesWith(ref->ref_inst->result_id(),
635 phi_inst->result_id());
636 }
637 new_blocks->push_back(std::move(new_blk_ptr));
638 context()->KillInst(ref->ref_inst);
639 }
640
GenDescCheckCode(BasicBlock::iterator ref_inst_itr,UptrVectorIterator<BasicBlock> ref_block_itr,uint32_t stage_idx,std::vector<std::unique_ptr<BasicBlock>> * new_blocks)641 void InstBindlessCheckPass::GenDescCheckCode(
642 BasicBlock::iterator ref_inst_itr,
643 UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
644 std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
645 // Look for reference through descriptor. If not, return.
646 RefAnalysis ref;
647 if (!AnalyzeDescriptorReference(&*ref_inst_itr, &ref)) return;
648 std::unique_ptr<BasicBlock> new_blk_ptr;
649 // Move original block's preceding instructions into first new block
650 MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
651 InstructionBuilder builder(
652 context(), &*new_blk_ptr,
653 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
654 new_blocks->push_back(std::move(new_blk_ptr));
655 // Determine if we can only do initialization check
656 uint32_t ref_id = builder.GetUintConstantId(0u);
657 spv::Op op = ref.ref_inst->opcode();
658 if (ref.desc_load_id != 0) {
659 uint32_t num_in_oprnds = ref.ref_inst->NumInOperands();
660 if ((op == spv::Op::OpImageRead && num_in_oprnds == 2) ||
661 (op == spv::Op::OpImageFetch && num_in_oprnds == 2) ||
662 (op == spv::Op::OpImageWrite && num_in_oprnds == 3)) {
663 Instruction* image_inst = get_def_use_mgr()->GetDef(ref.image_id);
664 uint32_t image_ty_id = image_inst->type_id();
665 Instruction* image_ty_inst = get_def_use_mgr()->GetDef(image_ty_id);
666 if (spv::Dim(image_ty_inst->GetSingleWordInOperand(kSpvTypeImageDim)) ==
667 spv::Dim::Buffer) {
668 if ((image_ty_inst->GetSingleWordInOperand(kSpvTypeImageDepth) == 0) &&
669 (image_ty_inst->GetSingleWordInOperand(kSpvTypeImageArrayed) ==
670 0) &&
671 (image_ty_inst->GetSingleWordInOperand(kSpvTypeImageMS) == 0)) {
672 ref_id = GenUintCastCode(ref.ref_inst->GetSingleWordInOperand(1),
673 &builder);
674 }
675 }
676 }
677 } else {
678 // For now, only do bounds check for non-aggregate types. Otherwise
679 // just do descriptor initialization check.
680 // TODO(greg-lunarg): Do bounds check for aggregate loads and stores
681 Instruction* ref_ptr_inst = get_def_use_mgr()->GetDef(ref.ptr_id);
682 Instruction* pte_type_inst = GetPointeeTypeInst(ref_ptr_inst);
683 spv::Op pte_type_op = pte_type_inst->opcode();
684 if (pte_type_op != spv::Op::OpTypeArray &&
685 pte_type_op != spv::Op::OpTypeRuntimeArray &&
686 pte_type_op != spv::Op::OpTypeStruct) {
687 ref_id = GenLastByteIdx(&ref, &builder);
688 }
689 }
690 // Read initialization/bounds from debug input buffer. If index id not yet
691 // set, binding is single descriptor, so set index to constant 0.
692 if (ref.desc_idx_id == 0) ref.desc_idx_id = builder.GetUintConstantId(0u);
693 uint32_t check_id =
694 GenDescCheckCall(ref.ref_inst->unique_id(), stage_idx, ref.var_id,
695 ref.desc_idx_id, ref_id, &builder);
696
697 // Generate runtime initialization/bounds test code with true branch
698 // being full reference and false branch being zero
699 // for the referenced value.
700 GenCheckCode(check_id, &ref, new_blocks);
701
702 // Move original block's remaining code into remainder/merge block and add
703 // to new blocks
704 BasicBlock* back_blk_ptr = &*new_blocks->back();
705 MovePostludeCode(ref_block_itr, back_blk_ptr);
706 }
707
InitializeInstBindlessCheck()708 void InstBindlessCheckPass::InitializeInstBindlessCheck() {
709 // Initialize base class
710 InitializeInstrument();
711 for (auto& anno : get_module()->annotations()) {
712 if (anno.opcode() == spv::Op::OpDecorate) {
713 if (spv::Decoration(anno.GetSingleWordInOperand(1u)) ==
714 spv::Decoration::DescriptorSet) {
715 var2desc_set_[anno.GetSingleWordInOperand(0u)] =
716 anno.GetSingleWordInOperand(2u);
717 } else if (spv::Decoration(anno.GetSingleWordInOperand(1u)) ==
718 spv::Decoration::Binding) {
719 var2binding_[anno.GetSingleWordInOperand(0u)] =
720 anno.GetSingleWordInOperand(2u);
721 }
722 }
723 }
724 }
725
ProcessImpl()726 Pass::Status InstBindlessCheckPass::ProcessImpl() {
727 // The memory model and linkage must always be updated for spirv-link to work
728 // correctly.
729 AddStorageBufferExt();
730 if (!get_feature_mgr()->HasExtension(kSPV_KHR_physical_storage_buffer)) {
731 context()->AddExtension("SPV_KHR_physical_storage_buffer");
732 }
733
734 context()->AddCapability(spv::Capability::PhysicalStorageBufferAddresses);
735 Instruction* memory_model = get_module()->GetMemoryModel();
736 memory_model->SetInOperand(
737 0u, {uint32_t(spv::AddressingModel::PhysicalStorageBuffer64)});
738
739 context()->AddCapability(spv::Capability::Linkage);
740
741 InstProcessFunction pfn =
742 [this](BasicBlock::iterator ref_inst_itr,
743 UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
744 std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
745 return GenDescCheckCode(ref_inst_itr, ref_block_itr, stage_idx,
746 new_blocks);
747 };
748
749 InstProcessEntryPointCallTree(pfn);
750 // This pass always changes the memory model, so that linking will work
751 // properly.
752 return Status::SuccessWithChange;
753 }
754
Process()755 Pass::Status InstBindlessCheckPass::Process() {
756 InitializeInstBindlessCheck();
757 return ProcessImpl();
758 }
759
760 } // namespace opt
761 } // namespace spvtools
762