• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2019 The Khronos Group Inc.
2 // Copyright (c) 2019 Valve Corporation
3 // Copyright (c) 2019 LunarG Inc.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 
17 #include "inst_buff_addr_check_pass.h"
18 
19 namespace spvtools {
20 namespace opt {
21 
CloneOriginalReference(Instruction * ref_inst,InstructionBuilder * builder)22 uint32_t InstBuffAddrCheckPass::CloneOriginalReference(
23     Instruction* ref_inst, InstructionBuilder* builder) {
24   // Clone original ref with new result id (if load)
25   assert(
26       (ref_inst->opcode() == SpvOpLoad || ref_inst->opcode() == SpvOpStore) &&
27       "unexpected ref");
28   std::unique_ptr<Instruction> new_ref_inst(ref_inst->Clone(context()));
29   uint32_t ref_result_id = ref_inst->result_id();
30   uint32_t new_ref_id = 0;
31   if (ref_result_id != 0) {
32     new_ref_id = TakeNextId();
33     new_ref_inst->SetResultId(new_ref_id);
34   }
35   // Register new reference and add to new block
36   Instruction* added_inst = builder->AddInstruction(std::move(new_ref_inst));
37   uid2offset_[added_inst->unique_id()] = uid2offset_[ref_inst->unique_id()];
38   if (new_ref_id != 0)
39     get_decoration_mgr()->CloneDecorations(ref_result_id, new_ref_id);
40   return new_ref_id;
41 }
42 
IsPhysicalBuffAddrReference(Instruction * ref_inst)43 bool InstBuffAddrCheckPass::IsPhysicalBuffAddrReference(Instruction* ref_inst) {
44   if (ref_inst->opcode() != SpvOpLoad && ref_inst->opcode() != SpvOpStore)
45     return false;
46   uint32_t ptr_id = ref_inst->GetSingleWordInOperand(0);
47   analysis::DefUseManager* du_mgr = get_def_use_mgr();
48   Instruction* ptr_inst = du_mgr->GetDef(ptr_id);
49   if (ptr_inst->opcode() != SpvOpAccessChain) return false;
50   uint32_t ptr_ty_id = ptr_inst->type_id();
51   Instruction* ptr_ty_inst = du_mgr->GetDef(ptr_ty_id);
52   if (ptr_ty_inst->GetSingleWordInOperand(0) !=
53       SpvStorageClassPhysicalStorageBufferEXT)
54     return false;
55   return true;
56 }
57 
58 // TODO(greg-lunarg): Refactor with InstBindlessCheckPass::GenCheckCode() ??
// Replaces the load/store |ref_inst| with a guarded version, appending the
// new blocks to |new_blocks|:
//   - the block currently at the back of |new_blocks| gets a selection
//     (conditional branch) on the bool |check_id|;
//   - the "valid" block re-issues a clone of |ref_inst|;
//   - the "invalid" block splits the uint64 address |ref_uptr_id| into low
//     and high uint32 halves, writes a debug-stream record
//     {error_id, lo, hi} for |stage_idx|, and (for loads only) produces a
//     null/zero stand-in result;
//   - the merge block receives, for loads only, a phi of the two results
//     that replaces all uses of the original result id. It is pushed
//     without a terminator so the caller can append the remaining code.
// |ref_inst| is killed at the end.
void InstBuffAddrCheckPass::GenCheckCode(
    uint32_t check_id, uint32_t error_id, uint32_t ref_uptr_id,
    uint32_t stage_idx, Instruction* ref_inst,
    std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
  BasicBlock* back_blk_ptr = &*new_blocks->back();
  InstructionBuilder builder(
      context(), back_blk_ptr,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  // Gen conditional branch on check_id. Valid branch generates original
  // reference. Invalid generates debug output and zero result (if needed).
  uint32_t merge_blk_id = TakeNextId();
  uint32_t valid_blk_id = TakeNextId();
  uint32_t invalid_blk_id = TakeNextId();
  std::unique_ptr<Instruction> merge_label(NewLabel(merge_blk_id));
  std::unique_ptr<Instruction> valid_label(NewLabel(valid_blk_id));
  std::unique_ptr<Instruction> invalid_label(NewLabel(invalid_blk_id));
  (void)builder.AddConditionalBranch(check_id, valid_blk_id, invalid_blk_id,
                                     merge_blk_id, SpvSelectionControlMaskNone);
  // Gen valid branch
  std::unique_ptr<BasicBlock> new_blk_ptr(
      new BasicBlock(std::move(valid_label)));
  builder.SetInsertPoint(&*new_blk_ptr);
  // new_ref_id is 0 when the reference has no result (a store).
  uint32_t new_ref_id = CloneOriginalReference(ref_inst, &builder);
  (void)builder.AddBranch(merge_blk_id);
  new_blocks->push_back(std::move(new_blk_ptr));
  // Gen invalid block
  new_blk_ptr.reset(new BasicBlock(std::move(invalid_label)));
  builder.SetInsertPoint(&*new_blk_ptr);
  // Convert uptr from uint64 to 2 uint32: a narrowing OpUConvert keeps the
  // low 32 bits; shifting right by 32 first yields the high 32 bits.
  Instruction* lo_uptr_inst =
      builder.AddUnaryOp(GetUintId(), SpvOpUConvert, ref_uptr_id);
  Instruction* rshift_uptr_inst =
      builder.AddBinaryOp(GetUint64Id(), SpvOpShiftRightLogical, ref_uptr_id,
                          builder.GetUintConstantId(32));
  Instruction* hi_uptr_inst = builder.AddUnaryOp(GetUintId(), SpvOpUConvert,
                                                 rshift_uptr_inst->result_id());
  // Record the error tagged with the original reference's instruction offset.
  GenDebugStreamWrite(
      uid2offset_[ref_inst->unique_id()], stage_idx,
      {error_id, lo_uptr_inst->result_id(), hi_uptr_inst->result_id()},
      &builder);
  // Gen zero for invalid load. If pointer type, need to convert uint64
  // zero to pointer; cannot create ConstantNull of pointer type.
  uint32_t null_id = 0;
  if (new_ref_id != 0) {
    uint32_t ref_type_id = ref_inst->type_id();
    analysis::TypeManager* type_mgr = context()->get_type_mgr();
    analysis::Type* ref_type = type_mgr->GetType(ref_type_id);
    if (ref_type->AsPointer() != nullptr) {
      uint32_t null_u64_id = GetNullId(GetUint64Id());
      Instruction* null_ptr_inst =
          builder.AddUnaryOp(ref_type_id, SpvOpConvertUToPtr, null_u64_id);
      null_id = null_ptr_inst->result_id();
    } else {
      null_id = GetNullId(ref_type_id);
    }
  }
  (void)builder.AddBranch(merge_blk_id);
  new_blocks->push_back(std::move(new_blk_ptr));
  // Gen merge block
  new_blk_ptr.reset(new BasicBlock(std::move(merge_label)));
  builder.SetInsertPoint(&*new_blk_ptr);
  // Gen phi of new reference and zero, if necessary, and replace the
  // result id of the original reference with that of the Phi. Kill original
  // reference.
  if (new_ref_id != 0) {
    Instruction* phi_inst =
        builder.AddPhi(ref_inst->type_id(),
                       {new_ref_id, valid_blk_id, null_id, invalid_blk_id});
    context()->ReplaceAllUsesWith(ref_inst->result_id(), phi_inst->result_id());
  }
  // The merge block is left open (no terminator) for the caller.
  new_blocks->push_back(std::move(new_blk_ptr));
  context()->KillInst(ref_inst);
}
132 
GetTypeAlignment(uint32_t type_id)133 uint32_t InstBuffAddrCheckPass::GetTypeAlignment(uint32_t type_id) {
134   Instruction* type_inst = get_def_use_mgr()->GetDef(type_id);
135   switch (type_inst->opcode()) {
136     case SpvOpTypeFloat:
137     case SpvOpTypeInt:
138     case SpvOpTypeVector:
139       return GetTypeLength(type_id);
140     case SpvOpTypeMatrix:
141       return GetTypeAlignment(type_inst->GetSingleWordInOperand(0));
142     case SpvOpTypeArray:
143     case SpvOpTypeRuntimeArray:
144       return GetTypeAlignment(type_inst->GetSingleWordInOperand(0));
145     case SpvOpTypeStruct: {
146       uint32_t max = 0;
147       type_inst->ForEachInId([&max, this](const uint32_t* iid) {
148         uint32_t alignment = GetTypeAlignment(*iid);
149         max = (alignment > max) ? alignment : max;
150       });
151       return max;
152     }
153     case SpvOpTypePointer:
154       assert(type_inst->GetSingleWordInOperand(0) ==
155                  SpvStorageClassPhysicalStorageBufferEXT &&
156              "unexpected pointer type");
157       return 8u;
158     default:
159       assert(false && "unexpected type");
160       return 0;
161   }
162 }
163 
GetTypeLength(uint32_t type_id)164 uint32_t InstBuffAddrCheckPass::GetTypeLength(uint32_t type_id) {
165   Instruction* type_inst = get_def_use_mgr()->GetDef(type_id);
166   switch (type_inst->opcode()) {
167     case SpvOpTypeFloat:
168     case SpvOpTypeInt:
169       return type_inst->GetSingleWordInOperand(0) / 8u;
170     case SpvOpTypeVector: {
171       uint32_t raw_cnt = type_inst->GetSingleWordInOperand(1);
172       uint32_t adj_cnt = (raw_cnt == 3u) ? 4u : raw_cnt;
173       return adj_cnt * GetTypeLength(type_inst->GetSingleWordInOperand(0));
174     }
175     case SpvOpTypeMatrix:
176       return type_inst->GetSingleWordInOperand(1) *
177              GetTypeLength(type_inst->GetSingleWordInOperand(0));
178     case SpvOpTypePointer:
179       assert(type_inst->GetSingleWordInOperand(0) ==
180                  SpvStorageClassPhysicalStorageBufferEXT &&
181              "unexpected pointer type");
182       return 8u;
183     case SpvOpTypeArray: {
184       uint32_t const_id = type_inst->GetSingleWordInOperand(1);
185       Instruction* const_inst = get_def_use_mgr()->GetDef(const_id);
186       uint32_t cnt = const_inst->GetSingleWordInOperand(0);
187       return cnt * GetTypeLength(type_inst->GetSingleWordInOperand(0));
188     }
189     case SpvOpTypeStruct: {
190       uint32_t len = 0;
191       type_inst->ForEachInId([&len, this](const uint32_t* iid) {
192         // Align struct length
193         uint32_t alignment = GetTypeAlignment(*iid);
194         uint32_t mod = len % alignment;
195         uint32_t diff = (mod != 0) ? alignment - mod : 0;
196         len += diff;
197         // Increment struct length by component length
198         uint32_t comp_len = GetTypeLength(*iid);
199         len += comp_len;
200       });
201       return len;
202     }
203     case SpvOpTypeRuntimeArray:
204     default:
205       assert(false && "unexpected type");
206       return 0;
207   }
208 }
209 
AddParam(uint32_t type_id,std::vector<uint32_t> * param_vec,std::unique_ptr<Function> * input_func)210 void InstBuffAddrCheckPass::AddParam(uint32_t type_id,
211                                      std::vector<uint32_t>* param_vec,
212                                      std::unique_ptr<Function>* input_func) {
213   uint32_t pid = TakeNextId();
214   param_vec->push_back(pid);
215   std::unique_ptr<Instruction> param_inst(new Instruction(
216       get_module()->context(), SpvOpFunctionParameter, type_id, pid, {}));
217   get_def_use_mgr()->AnalyzeInstDefUse(&*param_inst);
218   (*input_func)->AddParameter(std::move(param_inst));
219 }
220 
// Returns the id of the generated helper function
//   bool search_and_test(uint64_t ref_ptr, uint32_t len)
// The function is built and added to the module on first use, and its id is
// cached in search_test_func_id_.
//
// Generated control flow:
//   first block -> branch to loop header.
//   loop header -> OpPhi idx (1 from the first block, idx+1 from the
//                  continue block); OpLoopMerge(bound test, continue);
//                  branch to continue block.
//   continue    -> idx+1; load data[idx+1] from the debug input buffer;
//                  if it is > ref_ptr, exit to the bound test block;
//                  otherwise loop back to the header.
//   bound test  -> cand = idx+1-1; offset = ref_ptr - data[cand];
//                  blen = data[data[0] + cand - 1];
//                  return (offset + len) <= blen.
// "data" is the member at index kDebugInputDataOffset of the input buffer.
// As read by this code: data[0] holds the index where buffer lengths start;
// buffer start addresses are read from data[1], data[2], ...; the length
// paired with the address in data[i] is read from data[data[0] + i - 1].
// NOTE(review): the loop only exits once some data[i] exceeds ref_ptr. It
// presumably relies on addresses sorted ascending and terminated by a value
// larger than any valid pointer; TODO confirm with the buffer's producer.
// When ref_ptr is below data[1], the 64-bit subtraction wraps around, so the
// length test fails.
uint32_t InstBuffAddrCheckPass::GetSearchAndTestFuncId() {
  if (search_test_func_id_ == 0) {
    // Generate function "bool search_and_test(uint64_t ref_ptr, uint32_t len)"
    // which searches input buffer for buffer which most likely contains the
    // pointer value |ref_ptr| and verifies that the entire reference of
    // length |len| bytes is contained in the buffer.
    search_test_func_id_ = TakeNextId();
    analysis::TypeManager* type_mgr = context()->get_type_mgr();
    std::vector<const analysis::Type*> param_types = {
        type_mgr->GetType(GetUint64Id()), type_mgr->GetType(GetUintId())};
    analysis::Function func_ty(type_mgr->GetType(GetBoolId()), param_types);
    analysis::Type* reg_func_ty = type_mgr->GetRegisteredType(&func_ty);
    std::unique_ptr<Instruction> func_inst(
        new Instruction(get_module()->context(), SpvOpFunction, GetBoolId(),
                        search_test_func_id_,
                        {{spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
                          {SpvFunctionControlMaskNone}},
                         {spv_operand_type_t::SPV_OPERAND_TYPE_ID,
                          {type_mgr->GetTypeInstruction(reg_func_ty)}}}));
    get_def_use_mgr()->AnalyzeInstDefUse(&*func_inst);
    std::unique_ptr<Function> input_func =
        MakeUnique<Function>(std::move(func_inst));
    std::vector<uint32_t> param_vec;
    // Add ref_ptr and length parameters
    // (param_vec[0] = ref_ptr : uint64, param_vec[1] = len : uint32).
    AddParam(GetUint64Id(), &param_vec, &input_func);
    AddParam(GetUintId(), &param_vec, &input_func);
    // Empty first block.
    uint32_t first_blk_id = TakeNextId();
    std::unique_ptr<Instruction> first_blk_label(NewLabel(first_blk_id));
    std::unique_ptr<BasicBlock> first_blk_ptr =
        MakeUnique<BasicBlock>(std::move(first_blk_label));
    InstructionBuilder builder(
        context(), &*first_blk_ptr,
        IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
    uint32_t hdr_blk_id = TakeNextId();
    // Branch to search loop header
    std::unique_ptr<Instruction> hdr_blk_label(NewLabel(hdr_blk_id));
    (void)builder.AddInstruction(MakeUnique<Instruction>(
        context(), SpvOpBranch, 0, 0,
        std::initializer_list<Operand>{{SPV_OPERAND_TYPE_ID, {hdr_blk_id}}}));
    input_func->AddBasicBlock(std::move(first_blk_ptr));
    // Linear search loop header block
    // TODO(greg-lunarg): Implement binary search
    std::unique_ptr<BasicBlock> hdr_blk_ptr =
        MakeUnique<BasicBlock>(std::move(hdr_blk_label));
    builder.SetInsertPoint(&*hdr_blk_ptr);
    // Phi for search index. Starts with 1.
    uint32_t cont_blk_id = TakeNextId();
    std::unique_ptr<Instruction> cont_blk_label(NewLabel(cont_blk_id));
    // Deal with def-use cycle caused by search loop index computation.
    // Create Add and Phi instructions first, then do Def analysis on Add.
    // Add Phi and Add instructions and do Use analysis later.
    uint32_t idx_phi_id = TakeNextId();
    uint32_t idx_inc_id = TakeNextId();
    std::unique_ptr<Instruction> idx_inc_inst(new Instruction(
        context(), SpvOpIAdd, GetUintId(), idx_inc_id,
        {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {idx_phi_id}},
         {spv_operand_type_t::SPV_OPERAND_TYPE_ID,
          {builder.GetUintConstantId(1u)}}}));
    std::unique_ptr<Instruction> idx_phi_inst(new Instruction(
        context(), SpvOpPhi, GetUintId(), idx_phi_id,
        {{spv_operand_type_t::SPV_OPERAND_TYPE_ID,
          {builder.GetUintConstantId(1u)}},
         {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {first_blk_id}},
         {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {idx_inc_id}},
         {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {cont_blk_id}}}));
    get_def_use_mgr()->AnalyzeInstDef(&*idx_inc_inst);
    // Add (previously created) search index phi
    (void)builder.AddInstruction(std::move(idx_phi_inst));
    // LoopMerge
    // (merge target: bound test block; continue target: continue block).
    uint32_t bound_test_blk_id = TakeNextId();
    std::unique_ptr<Instruction> bound_test_blk_label(
        NewLabel(bound_test_blk_id));
    (void)builder.AddInstruction(MakeUnique<Instruction>(
        context(), SpvOpLoopMerge, 0, 0,
        std::initializer_list<Operand>{
            {SPV_OPERAND_TYPE_ID, {bound_test_blk_id}},
            {SPV_OPERAND_TYPE_ID, {cont_blk_id}},
            {SPV_OPERAND_TYPE_LITERAL_INTEGER, {SpvLoopControlMaskNone}}}));
    // Branch to continue/work block
    (void)builder.AddInstruction(MakeUnique<Instruction>(
        context(), SpvOpBranch, 0, 0,
        std::initializer_list<Operand>{{SPV_OPERAND_TYPE_ID, {cont_blk_id}}}));
    input_func->AddBasicBlock(std::move(hdr_blk_ptr));
    // Continue/Work Block. Read next buffer pointer and break if greater
    // than ref_ptr arg.
    std::unique_ptr<BasicBlock> cont_blk_ptr =
        MakeUnique<BasicBlock>(std::move(cont_blk_label));
    builder.SetInsertPoint(&*cont_blk_ptr);
    // Add (previously created) search index increment now.
    (void)builder.AddInstruction(std::move(idx_inc_inst));
    // Load next buffer address from debug input buffer
    uint32_t ibuf_id = GetInputBufferId();
    uint32_t ibuf_ptr_id = GetInputBufferPtrId();
    Instruction* uptr_ac_inst = builder.AddTernaryOp(
        ibuf_ptr_id, SpvOpAccessChain, ibuf_id,
        builder.GetUintConstantId(kDebugInputDataOffset), idx_inc_id);
    uint32_t ibuf_type_id = GetInputBufferTypeId();
    Instruction* uptr_load_inst =
        builder.AddUnaryOp(ibuf_type_id, SpvOpLoad, uptr_ac_inst->result_id());
    // If loaded address greater than ref_ptr arg, break, else branch back to
    // loop header
    Instruction* uptr_test_inst =
        builder.AddBinaryOp(GetBoolId(), SpvOpUGreaterThan,
                            uptr_load_inst->result_id(), param_vec[0]);
    (void)builder.AddConditionalBranch(uptr_test_inst->result_id(),
                                       bound_test_blk_id, hdr_blk_id,
                                       kInvalidId, SpvSelectionControlMaskNone);
    input_func->AddBasicBlock(std::move(cont_blk_ptr));
    // Bounds test block. Read length of selected buffer and test that
    // all len arg bytes are in buffer.
    std::unique_ptr<BasicBlock> bound_test_blk_ptr =
        MakeUnique<BasicBlock>(std::move(bound_test_blk_label));
    builder.SetInsertPoint(&*bound_test_blk_ptr);
    // Decrement index to point to previous/candidate buffer address
    Instruction* cand_idx_inst = builder.AddBinaryOp(
        GetUintId(), SpvOpISub, idx_inc_id, builder.GetUintConstantId(1u));
    // Load candidate buffer address
    Instruction* cand_ac_inst =
        builder.AddTernaryOp(ibuf_ptr_id, SpvOpAccessChain, ibuf_id,
                             builder.GetUintConstantId(kDebugInputDataOffset),
                             cand_idx_inst->result_id());
    Instruction* cand_load_inst =
        builder.AddUnaryOp(ibuf_type_id, SpvOpLoad, cand_ac_inst->result_id());
    // Compute offset of ref_ptr from candidate buffer address
    Instruction* offset_inst = builder.AddBinaryOp(
        ibuf_type_id, SpvOpISub, param_vec[0], cand_load_inst->result_id());
    // Convert ref length to uint64
    Instruction* ref_len_64_inst =
        builder.AddUnaryOp(ibuf_type_id, SpvOpUConvert, param_vec[1]);
    // Add ref length to ref offset to compute end of reference
    Instruction* ref_end_inst =
        builder.AddBinaryOp(ibuf_type_id, SpvOpIAdd, offset_inst->result_id(),
                            ref_len_64_inst->result_id());
    // Load starting index of lengths in input buffer and convert to uint32
    Instruction* len_start_ac_inst =
        builder.AddTernaryOp(ibuf_ptr_id, SpvOpAccessChain, ibuf_id,
                             builder.GetUintConstantId(kDebugInputDataOffset),
                             builder.GetUintConstantId(0u));
    Instruction* len_start_load_inst = builder.AddUnaryOp(
        ibuf_type_id, SpvOpLoad, len_start_ac_inst->result_id());
    Instruction* len_start_32_inst = builder.AddUnaryOp(
        GetUintId(), SpvOpUConvert, len_start_load_inst->result_id());
    // Decrement search index to get candidate buffer length index
    Instruction* cand_len_idx_inst =
        builder.AddBinaryOp(GetUintId(), SpvOpISub, cand_idx_inst->result_id(),
                            builder.GetUintConstantId(1u));
    // Add candidate length index to start index
    Instruction* len_idx_inst = builder.AddBinaryOp(
        GetUintId(), SpvOpIAdd, cand_len_idx_inst->result_id(),
        len_start_32_inst->result_id());
    // Load candidate buffer length
    Instruction* len_ac_inst =
        builder.AddTernaryOp(ibuf_ptr_id, SpvOpAccessChain, ibuf_id,
                             builder.GetUintConstantId(kDebugInputDataOffset),
                             len_idx_inst->result_id());
    Instruction* len_load_inst =
        builder.AddUnaryOp(ibuf_type_id, SpvOpLoad, len_ac_inst->result_id());
    // Test if reference end within candidate buffer length
    Instruction* len_test_inst = builder.AddBinaryOp(
        GetBoolId(), SpvOpULessThanEqual, ref_end_inst->result_id(),
        len_load_inst->result_id());
    // Return test result
    (void)builder.AddInstruction(MakeUnique<Instruction>(
        context(), SpvOpReturnValue, 0, 0,
        std::initializer_list<Operand>{
            {SPV_OPERAND_TYPE_ID, {len_test_inst->result_id()}}}));
    // Close block
    input_func->AddBasicBlock(std::move(bound_test_blk_ptr));
    // Close function and add function to module
    std::unique_ptr<Instruction> func_end_inst(
        new Instruction(get_module()->context(), SpvOpFunctionEnd, 0, 0, {}));
    get_def_use_mgr()->AnalyzeInstDefUse(&*func_end_inst);
    input_func->SetFunctionEnd(std::move(func_end_inst));
    context()->AddFunction(std::move(input_func));
    // Give the helper a readable OpName.
    context()->AddDebug2Inst(
        NewGlobalName(search_test_func_id_, "search_and_test"));
  }
  return search_test_func_id_;
}
401 
GenSearchAndTest(Instruction * ref_inst,InstructionBuilder * builder,uint32_t * ref_uptr_id)402 uint32_t InstBuffAddrCheckPass::GenSearchAndTest(Instruction* ref_inst,
403                                                  InstructionBuilder* builder,
404                                                  uint32_t* ref_uptr_id) {
405   // Enable Int64 if necessary
406   if (!get_feature_mgr()->HasCapability(SpvCapabilityInt64)) {
407     std::unique_ptr<Instruction> cap_int64_inst(new Instruction(
408         context(), SpvOpCapability, 0, 0,
409         std::initializer_list<Operand>{
410             {SPV_OPERAND_TYPE_CAPABILITY, {SpvCapabilityInt64}}}));
411     get_def_use_mgr()->AnalyzeInstDefUse(&*cap_int64_inst);
412     context()->AddCapability(std::move(cap_int64_inst));
413   }
414   // Convert reference pointer to uint64
415   uint32_t ref_ptr_id = ref_inst->GetSingleWordInOperand(0);
416   Instruction* ref_uptr_inst =
417       builder->AddUnaryOp(GetUint64Id(), SpvOpConvertPtrToU, ref_ptr_id);
418   *ref_uptr_id = ref_uptr_inst->result_id();
419   // Compute reference length in bytes
420   analysis::DefUseManager* du_mgr = get_def_use_mgr();
421   Instruction* ref_ptr_inst = du_mgr->GetDef(ref_ptr_id);
422   uint32_t ref_ptr_ty_id = ref_ptr_inst->type_id();
423   Instruction* ref_ptr_ty_inst = du_mgr->GetDef(ref_ptr_ty_id);
424   uint32_t ref_len = GetTypeLength(ref_ptr_ty_inst->GetSingleWordInOperand(1));
425   uint32_t ref_len_id = builder->GetUintConstantId(ref_len);
426   // Gen call to search and test function
427   const std::vector<uint32_t> args = {GetSearchAndTestFuncId(), *ref_uptr_id,
428                                       ref_len_id};
429   Instruction* call_inst =
430       builder->AddNaryOp(GetBoolId(), SpvOpFunctionCall, args);
431   uint32_t retval = call_inst->result_id();
432   return retval;
433 }
434 
GenBuffAddrCheckCode(BasicBlock::iterator ref_inst_itr,UptrVectorIterator<BasicBlock> ref_block_itr,uint32_t stage_idx,std::vector<std::unique_ptr<BasicBlock>> * new_blocks)435 void InstBuffAddrCheckPass::GenBuffAddrCheckCode(
436     BasicBlock::iterator ref_inst_itr,
437     UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
438     std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
439   // Look for reference through indexed descriptor. If found, analyze and
440   // save components. If not, return.
441   Instruction* ref_inst = &*ref_inst_itr;
442   if (!IsPhysicalBuffAddrReference(ref_inst)) return;
443   // Move original block's preceding instructions into first new block
444   std::unique_ptr<BasicBlock> new_blk_ptr;
445   MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
446   InstructionBuilder builder(
447       context(), &*new_blk_ptr,
448       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
449   new_blocks->push_back(std::move(new_blk_ptr));
450   uint32_t error_id = builder.GetUintConstantId(kInstErrorBuffAddrUnallocRef);
451   // Generate code to do search and test if all bytes of reference
452   // are within a listed buffer. Return reference pointer converted to uint64.
453   uint32_t ref_uptr_id;
454   uint32_t valid_id = GenSearchAndTest(ref_inst, &builder, &ref_uptr_id);
455   // Generate test of search results with true branch
456   // being full reference and false branch being debug output and zero
457   // for the referenced value.
458   GenCheckCode(valid_id, error_id, ref_uptr_id, stage_idx, ref_inst,
459                new_blocks);
460   // Move original block's remaining code into remainder/merge block and add
461   // to new blocks
462   BasicBlock* back_blk_ptr = &*new_blocks->back();
463   MovePostludeCode(ref_block_itr, back_blk_ptr);
464 }
465 
InitInstBuffAddrCheck()466 void InstBuffAddrCheckPass::InitInstBuffAddrCheck() {
467   // Initialize base class
468   InitializeInstrument();
469   // Initialize class
470   search_test_func_id_ = 0;
471 }
472 
ProcessImpl()473 Pass::Status InstBuffAddrCheckPass::ProcessImpl() {
474   // Perform bindless bounds check on each entry point function in module
475   InstProcessFunction pfn =
476       [this](BasicBlock::iterator ref_inst_itr,
477              UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
478              std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
479         return GenBuffAddrCheckCode(ref_inst_itr, ref_block_itr, stage_idx,
480                                     new_blocks);
481       };
482   bool modified = InstProcessEntryPointCallTree(pfn);
483   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
484 }
485 
Process()486 Pass::Status InstBuffAddrCheckPass::Process() {
487   if (!get_feature_mgr()->HasCapability(
488           SpvCapabilityPhysicalStorageBufferAddressesEXT))
489     return Status::SuccessWithoutChange;
490   InitInstBuffAddrCheck();
491   return ProcessImpl();
492 }
493 
494 }  // namespace opt
495 }  // namespace spvtools
496