• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2018 The Khronos Group Inc.
2 // Copyright (c) 2018 Valve Corporation
3 // Copyright (c) 2018 LunarG Inc.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 
17 #include "inst_bindless_check_pass.h"
18 
19 #include "source/spirv_constant.h"
20 
21 namespace spvtools {
22 namespace opt {
namespace {
// Input Operand Indices
//
// Each constant below is an index passed to Instruction::GetSingleWordInOperand
// (or SetInOperand) for the instruction named in its group comment.

// Image instructions accepted by GetImageId: the image / sampled image operand.
constexpr int kSpvImageSampleImageIdInIdx{0};
// OpSampledImage: the image and sampler operands.
constexpr int kSpvSampledImageImageIdInIdx{0};
constexpr int kSpvSampledImageSamplerIdInIdx{1};
// OpImage: the sampled image operand.
constexpr int kSpvImageSampledImageIdInIdx{0};
// OpCopyObject: the copied operand.
constexpr int kSpvCopyObjectOperandIdInIdx{0};
// OpLoad (also used for OpStore): the pointer operand.
constexpr int kSpvLoadPtrIdInIdx{0};
// OpAccessChain: the base pointer and the first index.
constexpr int kSpvAccessChainBaseIdInIdx{0};
constexpr int kSpvAccessChainIndex0IdInIdx{1};
// OpTypeArray / OpTypeRuntimeArray: the element type.
constexpr int kSpvTypeArrayTypeIdInIdx{0};
// OpVariable: the storage class.
constexpr int kSpvVariableStorageClassInIdx{0};
// OpTypePointer: the pointee type.
constexpr int kSpvTypePtrTypeIdInIdx{1};
// OpTypeImage: the Dim, Depth, Arrayed and MS operands.
constexpr int kSpvTypeImageDim{1};
constexpr int kSpvTypeImageDepth{2};
constexpr int kSpvTypeImageArrayed{3};
constexpr int kSpvTypeImageMS{4};
}  // namespace
41 
42 // This is a stub function for use with Import linkage
43 // clang-format off
44 // GLSL:
45 //bool inst_bindless_check_desc(const uint shader_id, const uint inst_num, const uvec4 stage_info, const uint desc_set,
46 //                              const uint binding, const uint desc_index, const uint byte_offset) {
47 //}
48 // clang-format on
GenDescCheckFunctionId()49 uint32_t InstBindlessCheckPass::GenDescCheckFunctionId() {
50   enum {
51     kShaderId = 0,
52     kInstructionIndex = 1,
53     kStageInfo = 2,
54     kDescSet = 3,
55     kDescBinding = 4,
56     kDescIndex = 5,
57     kByteOffset = 6,
58     kNumArgs
59   };
60   if (check_desc_func_id_ != 0) {
61     return check_desc_func_id_;
62   }
63 
64   analysis::TypeManager* type_mgr = context()->get_type_mgr();
65   const analysis::Integer* uint_type = GetInteger(32, false);
66   const analysis::Vector v4uint(uint_type, 4);
67   const analysis::Type* v4uint_type = type_mgr->GetRegisteredType(&v4uint);
68   std::vector<const analysis::Type*> param_types(kNumArgs, uint_type);
69   param_types[2] = v4uint_type;
70 
71   const uint32_t func_id = TakeNextId();
72   std::unique_ptr<Function> func =
73       StartFunction(func_id, type_mgr->GetBoolType(), param_types);
74 
75   func->SetFunctionEnd(EndFunction());
76 
77   static const std::string func_name{"inst_bindless_check_desc"};
78   context()->AddFunctionDeclaration(std::move(func));
79   context()->AddDebug2Inst(NewName(func_id, func_name));
80   std::vector<Operand> operands{
81       {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {func_id}},
82       {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
83        {uint32_t(spv::Decoration::LinkageAttributes)}},
84       {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_STRING,
85        utils::MakeVector(func_name.c_str())},
86       {spv_operand_type_t::SPV_OPERAND_TYPE_LINKAGE_TYPE,
87        {uint32_t(spv::LinkageType::Import)}},
88   };
89   get_decoration_mgr()->AddDecoration(spv::Op::OpDecorate, operands);
90 
91   check_desc_func_id_ = func_id;
92   // Make sure function doesn't get processed by
93   // InstrumentPass::InstProcessCallTreeFromRoots()
94   param2output_func_id_[3] = func_id;
95   return check_desc_func_id_;
96 }
97 
98 // clang-format off
99 // GLSL:
100 // result = inst_bindless_check_desc(shader_id, inst_idx, stage_info, desc_set, binding, desc_idx, offset);
101 //
102 // clang-format on
GenDescCheckCall(uint32_t inst_idx,uint32_t stage_idx,uint32_t var_id,uint32_t desc_idx_id,uint32_t offset_id,InstructionBuilder * builder)103 uint32_t InstBindlessCheckPass::GenDescCheckCall(
104     uint32_t inst_idx, uint32_t stage_idx, uint32_t var_id,
105     uint32_t desc_idx_id, uint32_t offset_id, InstructionBuilder* builder) {
106   const uint32_t func_id = GenDescCheckFunctionId();
107   const std::vector<uint32_t> args = {
108       builder->GetUintConstantId(shader_id_),
109       builder->GetUintConstantId(inst_idx),
110       GenStageInfo(stage_idx, builder),
111       builder->GetUintConstantId(var2desc_set_[var_id]),
112       builder->GetUintConstantId(var2binding_[var_id]),
113       GenUintCastCode(desc_idx_id, builder),
114       offset_id};
115   return GenReadFunctionCall(GetBoolId(), func_id, args, builder);
116 }
117 
CloneOriginalImage(uint32_t old_image_id,InstructionBuilder * builder)118 uint32_t InstBindlessCheckPass::CloneOriginalImage(
119     uint32_t old_image_id, InstructionBuilder* builder) {
120   Instruction* new_image_inst;
121   Instruction* old_image_inst = get_def_use_mgr()->GetDef(old_image_id);
122   if (old_image_inst->opcode() == spv::Op::OpLoad) {
123     new_image_inst = builder->AddLoad(
124         old_image_inst->type_id(),
125         old_image_inst->GetSingleWordInOperand(kSpvLoadPtrIdInIdx));
126   } else if (old_image_inst->opcode() == spv::Op::OpSampledImage) {
127     uint32_t clone_id = CloneOriginalImage(
128         old_image_inst->GetSingleWordInOperand(kSpvSampledImageImageIdInIdx),
129         builder);
130     new_image_inst = builder->AddBinaryOp(
131         old_image_inst->type_id(), spv::Op::OpSampledImage, clone_id,
132         old_image_inst->GetSingleWordInOperand(kSpvSampledImageSamplerIdInIdx));
133   } else if (old_image_inst->opcode() == spv::Op::OpImage) {
134     uint32_t clone_id = CloneOriginalImage(
135         old_image_inst->GetSingleWordInOperand(kSpvImageSampledImageIdInIdx),
136         builder);
137     new_image_inst = builder->AddUnaryOp(old_image_inst->type_id(),
138                                          spv::Op::OpImage, clone_id);
139   } else {
140     assert(old_image_inst->opcode() == spv::Op::OpCopyObject &&
141            "expecting OpCopyObject");
142     uint32_t clone_id = CloneOriginalImage(
143         old_image_inst->GetSingleWordInOperand(kSpvCopyObjectOperandIdInIdx),
144         builder);
145     // Since we are cloning, no need to create new copy
146     new_image_inst = get_def_use_mgr()->GetDef(clone_id);
147   }
148   uid2offset_[new_image_inst->unique_id()] =
149       uid2offset_[old_image_inst->unique_id()];
150   uint32_t new_image_id = new_image_inst->result_id();
151   get_decoration_mgr()->CloneDecorations(old_image_id, new_image_id);
152   return new_image_id;
153 }
154 
CloneOriginalReference(RefAnalysis * ref,InstructionBuilder * builder)155 uint32_t InstBindlessCheckPass::CloneOriginalReference(
156     RefAnalysis* ref, InstructionBuilder* builder) {
157   // If original is image based, start by cloning descriptor load
158   uint32_t new_image_id = 0;
159   if (ref->desc_load_id != 0) {
160     uint32_t old_image_id =
161         ref->ref_inst->GetSingleWordInOperand(kSpvImageSampleImageIdInIdx);
162     new_image_id = CloneOriginalImage(old_image_id, builder);
163   }
164   // Clone original reference
165   std::unique_ptr<Instruction> new_ref_inst(ref->ref_inst->Clone(context()));
166   uint32_t ref_result_id = ref->ref_inst->result_id();
167   uint32_t new_ref_id = 0;
168   if (ref_result_id != 0) {
169     new_ref_id = TakeNextId();
170     new_ref_inst->SetResultId(new_ref_id);
171   }
172   // Update new ref with new image if created
173   if (new_image_id != 0)
174     new_ref_inst->SetInOperand(kSpvImageSampleImageIdInIdx, {new_image_id});
175   // Register new reference and add to new block
176   Instruction* added_inst = builder->AddInstruction(std::move(new_ref_inst));
177   uid2offset_[added_inst->unique_id()] =
178       uid2offset_[ref->ref_inst->unique_id()];
179   if (new_ref_id != 0)
180     get_decoration_mgr()->CloneDecorations(ref_result_id, new_ref_id);
181   return new_ref_id;
182 }
183 
GetImageId(Instruction * inst)184 uint32_t InstBindlessCheckPass::GetImageId(Instruction* inst) {
185   switch (inst->opcode()) {
186     case spv::Op::OpImageSampleImplicitLod:
187     case spv::Op::OpImageSampleExplicitLod:
188     case spv::Op::OpImageSampleDrefImplicitLod:
189     case spv::Op::OpImageSampleDrefExplicitLod:
190     case spv::Op::OpImageSampleProjImplicitLod:
191     case spv::Op::OpImageSampleProjExplicitLod:
192     case spv::Op::OpImageSampleProjDrefImplicitLod:
193     case spv::Op::OpImageSampleProjDrefExplicitLod:
194     case spv::Op::OpImageGather:
195     case spv::Op::OpImageDrefGather:
196     case spv::Op::OpImageQueryLod:
197     case spv::Op::OpImageSparseSampleImplicitLod:
198     case spv::Op::OpImageSparseSampleExplicitLod:
199     case spv::Op::OpImageSparseSampleDrefImplicitLod:
200     case spv::Op::OpImageSparseSampleDrefExplicitLod:
201     case spv::Op::OpImageSparseSampleProjImplicitLod:
202     case spv::Op::OpImageSparseSampleProjExplicitLod:
203     case spv::Op::OpImageSparseSampleProjDrefImplicitLod:
204     case spv::Op::OpImageSparseSampleProjDrefExplicitLod:
205     case spv::Op::OpImageSparseGather:
206     case spv::Op::OpImageSparseDrefGather:
207     case spv::Op::OpImageFetch:
208     case spv::Op::OpImageRead:
209     case spv::Op::OpImageQueryFormat:
210     case spv::Op::OpImageQueryOrder:
211     case spv::Op::OpImageQuerySizeLod:
212     case spv::Op::OpImageQuerySize:
213     case spv::Op::OpImageQueryLevels:
214     case spv::Op::OpImageQuerySamples:
215     case spv::Op::OpImageSparseFetch:
216     case spv::Op::OpImageSparseRead:
217     case spv::Op::OpImageWrite:
218       return inst->GetSingleWordInOperand(kSpvImageSampleImageIdInIdx);
219     default:
220       break;
221   }
222   return 0;
223 }
224 
GetPointeeTypeInst(Instruction * ptr_inst)225 Instruction* InstBindlessCheckPass::GetPointeeTypeInst(Instruction* ptr_inst) {
226   uint32_t pte_ty_id = GetPointeeTypeId(ptr_inst);
227   return get_def_use_mgr()->GetDef(pte_ty_id);
228 }
229 
AnalyzeDescriptorReference(Instruction * ref_inst,RefAnalysis * ref)230 bool InstBindlessCheckPass::AnalyzeDescriptorReference(Instruction* ref_inst,
231                                                        RefAnalysis* ref) {
232   ref->ref_inst = ref_inst;
233   if (ref_inst->opcode() == spv::Op::OpLoad ||
234       ref_inst->opcode() == spv::Op::OpStore) {
235     ref->desc_load_id = 0;
236     ref->ptr_id = ref_inst->GetSingleWordInOperand(kSpvLoadPtrIdInIdx);
237     Instruction* ptr_inst = get_def_use_mgr()->GetDef(ref->ptr_id);
238     if (ptr_inst->opcode() != spv::Op::OpAccessChain) return false;
239     ref->var_id = ptr_inst->GetSingleWordInOperand(kSpvAccessChainBaseIdInIdx);
240     Instruction* var_inst = get_def_use_mgr()->GetDef(ref->var_id);
241     if (var_inst->opcode() != spv::Op::OpVariable) return false;
242     spv::StorageClass storage_class = spv::StorageClass(
243         var_inst->GetSingleWordInOperand(kSpvVariableStorageClassInIdx));
244     switch (storage_class) {
245       case spv::StorageClass::Uniform:
246       case spv::StorageClass::StorageBuffer:
247         break;
248       default:
249         return false;
250         break;
251     }
252     // Check for deprecated storage block form
253     if (storage_class == spv::StorageClass::Uniform) {
254       uint32_t var_ty_id = var_inst->type_id();
255       Instruction* var_ty_inst = get_def_use_mgr()->GetDef(var_ty_id);
256       uint32_t ptr_ty_id =
257           var_ty_inst->GetSingleWordInOperand(kSpvTypePtrTypeIdInIdx);
258       Instruction* ptr_ty_inst = get_def_use_mgr()->GetDef(ptr_ty_id);
259       spv::Op ptr_ty_op = ptr_ty_inst->opcode();
260       uint32_t block_ty_id =
261           (ptr_ty_op == spv::Op::OpTypeArray ||
262            ptr_ty_op == spv::Op::OpTypeRuntimeArray)
263               ? ptr_ty_inst->GetSingleWordInOperand(kSpvTypeArrayTypeIdInIdx)
264               : ptr_ty_id;
265       assert(get_def_use_mgr()->GetDef(block_ty_id)->opcode() ==
266                  spv::Op::OpTypeStruct &&
267              "unexpected block type");
268       bool block_found = get_decoration_mgr()->FindDecoration(
269           block_ty_id, uint32_t(spv::Decoration::Block),
270           [](const Instruction&) { return true; });
271       if (!block_found) {
272         // If block decoration not found, verify deprecated form of SSBO
273         bool buffer_block_found = get_decoration_mgr()->FindDecoration(
274             block_ty_id, uint32_t(spv::Decoration::BufferBlock),
275             [](const Instruction&) { return true; });
276         USE_ASSERT(buffer_block_found && "block decoration not found");
277         storage_class = spv::StorageClass::StorageBuffer;
278       }
279     }
280     ref->strg_class = uint32_t(storage_class);
281     Instruction* desc_type_inst = GetPointeeTypeInst(var_inst);
282     switch (desc_type_inst->opcode()) {
283       case spv::Op::OpTypeArray:
284       case spv::Op::OpTypeRuntimeArray:
285         // A load through a descriptor array will have at least 3 operands. We
286         // do not want to instrument loads of descriptors here which are part of
287         // an image-based reference.
288         if (ptr_inst->NumInOperands() < 3) return false;
289         ref->desc_idx_id =
290             ptr_inst->GetSingleWordInOperand(kSpvAccessChainIndex0IdInIdx);
291         break;
292       default:
293         break;
294     }
295     auto decos =
296         context()->get_decoration_mgr()->GetDecorationsFor(ref->var_id, false);
297     for (const auto& deco : decos) {
298       spv::Decoration d = spv::Decoration(deco->GetSingleWordInOperand(1u));
299       if (d == spv::Decoration::DescriptorSet) {
300         ref->set = deco->GetSingleWordInOperand(2u);
301       } else if (d == spv::Decoration::Binding) {
302         ref->binding = deco->GetSingleWordInOperand(2u);
303       }
304     }
305     return true;
306   }
307   // Reference is not load or store. If not an image-based reference, return.
308   ref->image_id = GetImageId(ref_inst);
309   if (ref->image_id == 0) return false;
310   // Search for descriptor load
311   uint32_t desc_load_id = ref->image_id;
312   Instruction* desc_load_inst;
313   for (;;) {
314     desc_load_inst = get_def_use_mgr()->GetDef(desc_load_id);
315     if (desc_load_inst->opcode() == spv::Op::OpSampledImage)
316       desc_load_id =
317           desc_load_inst->GetSingleWordInOperand(kSpvSampledImageImageIdInIdx);
318     else if (desc_load_inst->opcode() == spv::Op::OpImage)
319       desc_load_id =
320           desc_load_inst->GetSingleWordInOperand(kSpvImageSampledImageIdInIdx);
321     else if (desc_load_inst->opcode() == spv::Op::OpCopyObject)
322       desc_load_id =
323           desc_load_inst->GetSingleWordInOperand(kSpvCopyObjectOperandIdInIdx);
324     else
325       break;
326   }
327   if (desc_load_inst->opcode() != spv::Op::OpLoad) {
328     // TODO(greg-lunarg): Handle additional possibilities?
329     return false;
330   }
331   ref->desc_load_id = desc_load_id;
332   ref->ptr_id = desc_load_inst->GetSingleWordInOperand(kSpvLoadPtrIdInIdx);
333   Instruction* ptr_inst = get_def_use_mgr()->GetDef(ref->ptr_id);
334   if (ptr_inst->opcode() == spv::Op::OpVariable) {
335     ref->desc_idx_id = 0;
336     ref->var_id = ref->ptr_id;
337   } else if (ptr_inst->opcode() == spv::Op::OpAccessChain) {
338     if (ptr_inst->NumInOperands() != 2) {
339       assert(false && "unexpected bindless index number");
340       return false;
341     }
342     ref->desc_idx_id =
343         ptr_inst->GetSingleWordInOperand(kSpvAccessChainIndex0IdInIdx);
344     ref->var_id = ptr_inst->GetSingleWordInOperand(kSpvAccessChainBaseIdInIdx);
345     Instruction* var_inst = get_def_use_mgr()->GetDef(ref->var_id);
346     if (var_inst->opcode() != spv::Op::OpVariable) {
347       assert(false && "unexpected bindless base");
348       return false;
349     }
350   } else {
351     // TODO(greg-lunarg): Handle additional possibilities?
352     return false;
353   }
354   auto decos =
355       context()->get_decoration_mgr()->GetDecorationsFor(ref->var_id, false);
356   for (const auto& deco : decos) {
357     spv::Decoration d = spv::Decoration(deco->GetSingleWordInOperand(1u));
358     if (d == spv::Decoration::DescriptorSet) {
359       ref->set = deco->GetSingleWordInOperand(2u);
360     } else if (d == spv::Decoration::Binding) {
361       ref->binding = deco->GetSingleWordInOperand(2u);
362     }
363   }
364   return true;
365 }
366 
FindStride(uint32_t ty_id,uint32_t stride_deco)367 uint32_t InstBindlessCheckPass::FindStride(uint32_t ty_id,
368                                            uint32_t stride_deco) {
369   uint32_t stride = 0xdeadbeef;
370   bool found = get_decoration_mgr()->FindDecoration(
371       ty_id, stride_deco, [&stride](const Instruction& deco_inst) {
372         stride = deco_inst.GetSingleWordInOperand(2u);
373         return true;
374       });
375   USE_ASSERT(found && "stride not found");
376   return stride;
377 }
378 
ByteSize(uint32_t ty_id,uint32_t matrix_stride,bool col_major,bool in_matrix)379 uint32_t InstBindlessCheckPass::ByteSize(uint32_t ty_id, uint32_t matrix_stride,
380                                          bool col_major, bool in_matrix) {
381   analysis::TypeManager* type_mgr = context()->get_type_mgr();
382   const analysis::Type* sz_ty = type_mgr->GetType(ty_id);
383   if (sz_ty->kind() == analysis::Type::kPointer) {
384     // Assuming PhysicalStorageBuffer pointer
385     return 8;
386   }
387   if (sz_ty->kind() == analysis::Type::kMatrix) {
388     assert(matrix_stride != 0 && "missing matrix stride");
389     const analysis::Matrix* m_ty = sz_ty->AsMatrix();
390     if (col_major) {
391       return m_ty->element_count() * matrix_stride;
392     } else {
393       const analysis::Vector* v_ty = m_ty->element_type()->AsVector();
394       return v_ty->element_count() * matrix_stride;
395     }
396   }
397   uint32_t size = 1;
398   if (sz_ty->kind() == analysis::Type::kVector) {
399     const analysis::Vector* v_ty = sz_ty->AsVector();
400     size = v_ty->element_count();
401     const analysis::Type* comp_ty = v_ty->element_type();
402     // if vector in row major matrix, the vector is strided so return the
403     // number of bytes spanned by the vector
404     if (in_matrix && !col_major && matrix_stride > 0) {
405       uint32_t comp_ty_id = type_mgr->GetId(comp_ty);
406       return (size - 1) * matrix_stride + ByteSize(comp_ty_id, 0, false, false);
407     }
408     sz_ty = comp_ty;
409   }
410   switch (sz_ty->kind()) {
411     case analysis::Type::kFloat: {
412       const analysis::Float* f_ty = sz_ty->AsFloat();
413       size *= f_ty->width();
414     } break;
415     case analysis::Type::kInteger: {
416       const analysis::Integer* i_ty = sz_ty->AsInteger();
417       size *= i_ty->width();
418     } break;
419     default: { assert(false && "unexpected type"); } break;
420   }
421   size /= 8;
422   return size;
423 }
424 
GenLastByteIdx(RefAnalysis * ref,InstructionBuilder * builder)425 uint32_t InstBindlessCheckPass::GenLastByteIdx(RefAnalysis* ref,
426                                                InstructionBuilder* builder) {
427   // Find outermost buffer type and its access chain index
428   Instruction* var_inst = get_def_use_mgr()->GetDef(ref->var_id);
429   Instruction* desc_ty_inst = GetPointeeTypeInst(var_inst);
430   uint32_t buff_ty_id;
431   uint32_t ac_in_idx = 1;
432   switch (desc_ty_inst->opcode()) {
433     case spv::Op::OpTypeArray:
434     case spv::Op::OpTypeRuntimeArray:
435       buff_ty_id = desc_ty_inst->GetSingleWordInOperand(0);
436       ++ac_in_idx;
437       break;
438     default:
439       assert(desc_ty_inst->opcode() == spv::Op::OpTypeStruct &&
440              "unexpected descriptor type");
441       buff_ty_id = desc_ty_inst->result_id();
442       break;
443   }
444   // Process remaining access chain indices
445   Instruction* ac_inst = get_def_use_mgr()->GetDef(ref->ptr_id);
446   uint32_t curr_ty_id = buff_ty_id;
447   uint32_t sum_id = 0u;
448   uint32_t matrix_stride = 0u;
449   bool col_major = false;
450   uint32_t matrix_stride_id = 0u;
451   bool in_matrix = false;
452   while (ac_in_idx < ac_inst->NumInOperands()) {
453     uint32_t curr_idx_id = ac_inst->GetSingleWordInOperand(ac_in_idx);
454     Instruction* curr_ty_inst = get_def_use_mgr()->GetDef(curr_ty_id);
455     uint32_t curr_offset_id = 0;
456     switch (curr_ty_inst->opcode()) {
457       case spv::Op::OpTypeArray:
458       case spv::Op::OpTypeRuntimeArray: {
459         // Get array stride and multiply by current index
460         uint32_t arr_stride =
461             FindStride(curr_ty_id, uint32_t(spv::Decoration::ArrayStride));
462         uint32_t arr_stride_id = builder->GetUintConstantId(arr_stride);
463         uint32_t curr_idx_32b_id = Gen32BitCvtCode(curr_idx_id, builder);
464         Instruction* curr_offset_inst = builder->AddBinaryOp(
465             GetUintId(), spv::Op::OpIMul, arr_stride_id, curr_idx_32b_id);
466         curr_offset_id = curr_offset_inst->result_id();
467         // Get element type for next step
468         curr_ty_id = curr_ty_inst->GetSingleWordInOperand(0);
469       } break;
470       case spv::Op::OpTypeMatrix: {
471         assert(matrix_stride != 0 && "missing matrix stride");
472         matrix_stride_id = builder->GetUintConstantId(matrix_stride);
473         uint32_t vec_ty_id = curr_ty_inst->GetSingleWordInOperand(0);
474         // If column major, multiply column index by matrix stride, otherwise
475         // by vector component size and save matrix stride for vector (row)
476         // index
477         uint32_t col_stride_id;
478         if (col_major) {
479           col_stride_id = matrix_stride_id;
480         } else {
481           Instruction* vec_ty_inst = get_def_use_mgr()->GetDef(vec_ty_id);
482           uint32_t comp_ty_id = vec_ty_inst->GetSingleWordInOperand(0u);
483           uint32_t col_stride = ByteSize(comp_ty_id, 0u, false, false);
484           col_stride_id = builder->GetUintConstantId(col_stride);
485         }
486         uint32_t curr_idx_32b_id = Gen32BitCvtCode(curr_idx_id, builder);
487         Instruction* curr_offset_inst = builder->AddBinaryOp(
488             GetUintId(), spv::Op::OpIMul, col_stride_id, curr_idx_32b_id);
489         curr_offset_id = curr_offset_inst->result_id();
490         // Get element type for next step
491         curr_ty_id = vec_ty_id;
492         in_matrix = true;
493       } break;
494       case spv::Op::OpTypeVector: {
495         // If inside a row major matrix type, multiply index by matrix stride,
496         // else multiply by component size
497         uint32_t comp_ty_id = curr_ty_inst->GetSingleWordInOperand(0u);
498         uint32_t curr_idx_32b_id = Gen32BitCvtCode(curr_idx_id, builder);
499         if (in_matrix && !col_major) {
500           Instruction* curr_offset_inst = builder->AddBinaryOp(
501               GetUintId(), spv::Op::OpIMul, matrix_stride_id, curr_idx_32b_id);
502           curr_offset_id = curr_offset_inst->result_id();
503         } else {
504           uint32_t comp_ty_sz = ByteSize(comp_ty_id, 0u, false, false);
505           uint32_t comp_ty_sz_id = builder->GetUintConstantId(comp_ty_sz);
506           Instruction* curr_offset_inst = builder->AddBinaryOp(
507               GetUintId(), spv::Op::OpIMul, comp_ty_sz_id, curr_idx_32b_id);
508           curr_offset_id = curr_offset_inst->result_id();
509         }
510         // Get element type for next step
511         curr_ty_id = comp_ty_id;
512       } break;
513       case spv::Op::OpTypeStruct: {
514         // Get buffer byte offset for the referenced member
515         Instruction* curr_idx_inst = get_def_use_mgr()->GetDef(curr_idx_id);
516         assert(curr_idx_inst->opcode() == spv::Op::OpConstant &&
517                "unexpected struct index");
518         uint32_t member_idx = curr_idx_inst->GetSingleWordInOperand(0);
519         uint32_t member_offset = 0xdeadbeef;
520         bool found = get_decoration_mgr()->FindDecoration(
521             curr_ty_id, uint32_t(spv::Decoration::Offset),
522             [&member_idx, &member_offset](const Instruction& deco_inst) {
523               if (deco_inst.GetSingleWordInOperand(1u) != member_idx)
524                 return false;
525               member_offset = deco_inst.GetSingleWordInOperand(3u);
526               return true;
527             });
528         USE_ASSERT(found && "member offset not found");
529         curr_offset_id = builder->GetUintConstantId(member_offset);
530         // Look for matrix stride for this member if there is one. The matrix
531         // stride is not on the matrix type, but in a OpMemberDecorate on the
532         // enclosing struct type at the member index. If none found, reset
533         // stride to 0.
534         found = get_decoration_mgr()->FindDecoration(
535             curr_ty_id, uint32_t(spv::Decoration::MatrixStride),
536             [&member_idx, &matrix_stride](const Instruction& deco_inst) {
537               if (deco_inst.GetSingleWordInOperand(1u) != member_idx)
538                 return false;
539               matrix_stride = deco_inst.GetSingleWordInOperand(3u);
540               return true;
541             });
542         if (!found) matrix_stride = 0;
543         // Look for column major decoration
544         found = get_decoration_mgr()->FindDecoration(
545             curr_ty_id, uint32_t(spv::Decoration::ColMajor),
546             [&member_idx, &col_major](const Instruction& deco_inst) {
547               if (deco_inst.GetSingleWordInOperand(1u) != member_idx)
548                 return false;
549               col_major = true;
550               return true;
551             });
552         if (!found) col_major = false;
553         // Get element type for next step
554         curr_ty_id = curr_ty_inst->GetSingleWordInOperand(member_idx);
555       } break;
556       default: { assert(false && "unexpected non-composite type"); } break;
557     }
558     if (sum_id == 0)
559       sum_id = curr_offset_id;
560     else {
561       Instruction* sum_inst =
562           builder->AddIAdd(GetUintId(), sum_id, curr_offset_id);
563       sum_id = sum_inst->result_id();
564     }
565     ++ac_in_idx;
566   }
567   // Add in offset of last byte of referenced object
568   uint32_t bsize = ByteSize(curr_ty_id, matrix_stride, col_major, in_matrix);
569   uint32_t last = bsize - 1;
570   uint32_t last_id = builder->GetUintConstantId(last);
571   Instruction* sum_inst = builder->AddIAdd(GetUintId(), sum_id, last_id);
572   return sum_inst->result_id();
573 }
574 
GenCheckCode(uint32_t check_id,RefAnalysis * ref,std::vector<std::unique_ptr<BasicBlock>> * new_blocks)575 void InstBindlessCheckPass::GenCheckCode(
576     uint32_t check_id, RefAnalysis* ref,
577     std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
578   BasicBlock* back_blk_ptr = &*new_blocks->back();
579   InstructionBuilder builder(
580       context(), back_blk_ptr,
581       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
582   // Gen conditional branch on check_id. Valid branch generates original
583   // reference. Invalid generates debug output and zero result (if needed).
584   uint32_t merge_blk_id = TakeNextId();
585   uint32_t valid_blk_id = TakeNextId();
586   uint32_t invalid_blk_id = TakeNextId();
587   std::unique_ptr<Instruction> merge_label(NewLabel(merge_blk_id));
588   std::unique_ptr<Instruction> valid_label(NewLabel(valid_blk_id));
589   std::unique_ptr<Instruction> invalid_label(NewLabel(invalid_blk_id));
590   (void)builder.AddConditionalBranch(
591       check_id, valid_blk_id, invalid_blk_id, merge_blk_id,
592       uint32_t(spv::SelectionControlMask::MaskNone));
593   // Gen valid bounds branch
594   std::unique_ptr<BasicBlock> new_blk_ptr(
595       new BasicBlock(std::move(valid_label)));
596   builder.SetInsertPoint(&*new_blk_ptr);
597   uint32_t new_ref_id = CloneOriginalReference(ref, &builder);
598   uint32_t null_id = 0;
599   uint32_t ref_type_id = ref->ref_inst->type_id();
600   (void)builder.AddBranch(merge_blk_id);
601   new_blocks->push_back(std::move(new_blk_ptr));
602   // Gen invalid block
603   new_blk_ptr.reset(new BasicBlock(std::move(invalid_label)));
604   builder.SetInsertPoint(&*new_blk_ptr);
605 
606   // Generate a ConstantNull, converting to uint64 if the type cannot be a null.
607   if (new_ref_id != 0) {
608     analysis::TypeManager* type_mgr = context()->get_type_mgr();
609     analysis::Type* ref_type = type_mgr->GetType(ref_type_id);
610     if (ref_type->AsPointer() != nullptr) {
611       context()->AddCapability(spv::Capability::Int64);
612       uint32_t null_u64_id = GetNullId(GetUint64Id());
613       Instruction* null_ptr_inst = builder.AddUnaryOp(
614           ref_type_id, spv::Op::OpConvertUToPtr, null_u64_id);
615       null_id = null_ptr_inst->result_id();
616     } else {
617       null_id = GetNullId(ref_type_id);
618     }
619   }
620   // Remember last invalid block id
621   uint32_t last_invalid_blk_id = new_blk_ptr->GetLabelInst()->result_id();
622   // Gen zero for invalid  reference
623   (void)builder.AddBranch(merge_blk_id);
624   new_blocks->push_back(std::move(new_blk_ptr));
625   // Gen merge block
626   new_blk_ptr.reset(new BasicBlock(std::move(merge_label)));
627   builder.SetInsertPoint(&*new_blk_ptr);
628   // Gen phi of new reference and zero, if necessary, and replace the
629   // result id of the original reference with that of the Phi. Kill original
630   // reference.
631   if (new_ref_id != 0) {
632     Instruction* phi_inst = builder.AddPhi(
633         ref_type_id, {new_ref_id, valid_blk_id, null_id, last_invalid_blk_id});
634     context()->ReplaceAllUsesWith(ref->ref_inst->result_id(),
635                                   phi_inst->result_id());
636   }
637   new_blocks->push_back(std::move(new_blk_ptr));
638   context()->KillInst(ref->ref_inst);
639 }
640 
GenDescCheckCode(BasicBlock::iterator ref_inst_itr,UptrVectorIterator<BasicBlock> ref_block_itr,uint32_t stage_idx,std::vector<std::unique_ptr<BasicBlock>> * new_blocks)641 void InstBindlessCheckPass::GenDescCheckCode(
642     BasicBlock::iterator ref_inst_itr,
643     UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
644     std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
645   // Look for reference through descriptor. If not, return.
646   RefAnalysis ref;
647   if (!AnalyzeDescriptorReference(&*ref_inst_itr, &ref)) return;
648   std::unique_ptr<BasicBlock> new_blk_ptr;
649   // Move original block's preceding instructions into first new block
650   MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
651   InstructionBuilder builder(
652       context(), &*new_blk_ptr,
653       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
654   new_blocks->push_back(std::move(new_blk_ptr));
655   // Determine if we can only do initialization check
656   uint32_t ref_id = builder.GetUintConstantId(0u);
657   spv::Op op = ref.ref_inst->opcode();
658   if (ref.desc_load_id != 0) {
659     uint32_t num_in_oprnds = ref.ref_inst->NumInOperands();
660     if ((op == spv::Op::OpImageRead && num_in_oprnds == 2) ||
661         (op == spv::Op::OpImageFetch && num_in_oprnds == 2) ||
662         (op == spv::Op::OpImageWrite && num_in_oprnds == 3)) {
663       Instruction* image_inst = get_def_use_mgr()->GetDef(ref.image_id);
664       uint32_t image_ty_id = image_inst->type_id();
665       Instruction* image_ty_inst = get_def_use_mgr()->GetDef(image_ty_id);
666       if (spv::Dim(image_ty_inst->GetSingleWordInOperand(kSpvTypeImageDim)) ==
667           spv::Dim::Buffer) {
668         if ((image_ty_inst->GetSingleWordInOperand(kSpvTypeImageDepth) == 0) &&
669             (image_ty_inst->GetSingleWordInOperand(kSpvTypeImageArrayed) ==
670              0) &&
671             (image_ty_inst->GetSingleWordInOperand(kSpvTypeImageMS) == 0)) {
672           ref_id = GenUintCastCode(ref.ref_inst->GetSingleWordInOperand(1),
673                                    &builder);
674         }
675       }
676     }
677   } else {
678     // For now, only do bounds check for non-aggregate types. Otherwise
679     // just do descriptor initialization check.
680     // TODO(greg-lunarg): Do bounds check for aggregate loads and stores
681     Instruction* ref_ptr_inst = get_def_use_mgr()->GetDef(ref.ptr_id);
682     Instruction* pte_type_inst = GetPointeeTypeInst(ref_ptr_inst);
683     spv::Op pte_type_op = pte_type_inst->opcode();
684     if (pte_type_op != spv::Op::OpTypeArray &&
685         pte_type_op != spv::Op::OpTypeRuntimeArray &&
686         pte_type_op != spv::Op::OpTypeStruct) {
687       ref_id = GenLastByteIdx(&ref, &builder);
688     }
689   }
690   // Read initialization/bounds from debug input buffer. If index id not yet
691   // set, binding is single descriptor, so set index to constant 0.
692   if (ref.desc_idx_id == 0) ref.desc_idx_id = builder.GetUintConstantId(0u);
693   uint32_t check_id =
694       GenDescCheckCall(ref.ref_inst->unique_id(), stage_idx, ref.var_id,
695                        ref.desc_idx_id, ref_id, &builder);
696 
697   // Generate runtime initialization/bounds test code with true branch
698   // being full reference and false branch being zero
699   // for the referenced value.
700   GenCheckCode(check_id, &ref, new_blocks);
701 
702   // Move original block's remaining code into remainder/merge block and add
703   // to new blocks
704   BasicBlock* back_blk_ptr = &*new_blocks->back();
705   MovePostludeCode(ref_block_itr, back_blk_ptr);
706 }
707 
InitializeInstBindlessCheck()708 void InstBindlessCheckPass::InitializeInstBindlessCheck() {
709   // Initialize base class
710   InitializeInstrument();
711   for (auto& anno : get_module()->annotations()) {
712     if (anno.opcode() == spv::Op::OpDecorate) {
713       if (spv::Decoration(anno.GetSingleWordInOperand(1u)) ==
714           spv::Decoration::DescriptorSet) {
715         var2desc_set_[anno.GetSingleWordInOperand(0u)] =
716             anno.GetSingleWordInOperand(2u);
717       } else if (spv::Decoration(anno.GetSingleWordInOperand(1u)) ==
718                  spv::Decoration::Binding) {
719         var2binding_[anno.GetSingleWordInOperand(0u)] =
720             anno.GetSingleWordInOperand(2u);
721       }
722     }
723   }
724 }
725 
ProcessImpl()726 Pass::Status InstBindlessCheckPass::ProcessImpl() {
727   // The memory model and linkage must always be updated for spirv-link to work
728   // correctly.
729   AddStorageBufferExt();
730   if (!get_feature_mgr()->HasExtension(kSPV_KHR_physical_storage_buffer)) {
731     context()->AddExtension("SPV_KHR_physical_storage_buffer");
732   }
733 
734   context()->AddCapability(spv::Capability::PhysicalStorageBufferAddresses);
735   Instruction* memory_model = get_module()->GetMemoryModel();
736   memory_model->SetInOperand(
737       0u, {uint32_t(spv::AddressingModel::PhysicalStorageBuffer64)});
738 
739   context()->AddCapability(spv::Capability::Linkage);
740 
741   InstProcessFunction pfn =
742       [this](BasicBlock::iterator ref_inst_itr,
743              UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
744              std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
745         return GenDescCheckCode(ref_inst_itr, ref_block_itr, stage_idx,
746                                 new_blocks);
747       };
748 
749   InstProcessEntryPointCallTree(pfn);
750   // This pass always changes the memory model, so that linking will work
751   // properly.
752   return Status::SuccessWithChange;
753 }
754 
Process()755 Pass::Status InstBindlessCheckPass::Process() {
756   InitializeInstBindlessCheck();
757   return ProcessImpl();
758 }
759 
760 }  // namespace opt
761 }  // namespace spvtools
762