• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2019 The Khronos Group Inc.
2 // Copyright (c) 2019 Valve Corporation
3 // Copyright (c) 2019 LunarG Inc.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 
17 #include "inst_buff_addr_check_pass.h"
18 
19 namespace spvtools {
20 namespace opt {
21 
CloneOriginalReference(Instruction * ref_inst,InstructionBuilder * builder)22 uint32_t InstBuffAddrCheckPass::CloneOriginalReference(
23     Instruction* ref_inst, InstructionBuilder* builder) {
24   // Clone original ref with new result id (if load)
25   assert((ref_inst->opcode() == spv::Op::OpLoad ||
26           ref_inst->opcode() == spv::Op::OpStore) &&
27          "unexpected ref");
28   std::unique_ptr<Instruction> new_ref_inst(ref_inst->Clone(context()));
29   uint32_t ref_result_id = ref_inst->result_id();
30   uint32_t new_ref_id = 0;
31   if (ref_result_id != 0) {
32     new_ref_id = TakeNextId();
33     new_ref_inst->SetResultId(new_ref_id);
34   }
35   // Register new reference and add to new block
36   Instruction* added_inst = builder->AddInstruction(std::move(new_ref_inst));
37   uid2offset_[added_inst->unique_id()] = uid2offset_[ref_inst->unique_id()];
38   if (new_ref_id != 0)
39     get_decoration_mgr()->CloneDecorations(ref_result_id, new_ref_id);
40   return new_ref_id;
41 }
42 
IsPhysicalBuffAddrReference(Instruction * ref_inst)43 bool InstBuffAddrCheckPass::IsPhysicalBuffAddrReference(Instruction* ref_inst) {
44   if (ref_inst->opcode() != spv::Op::OpLoad &&
45       ref_inst->opcode() != spv::Op::OpStore)
46     return false;
47   uint32_t ptr_id = ref_inst->GetSingleWordInOperand(0);
48   analysis::DefUseManager* du_mgr = get_def_use_mgr();
49   Instruction* ptr_inst = du_mgr->GetDef(ptr_id);
50   if (ptr_inst->opcode() != spv::Op::OpAccessChain) return false;
51   uint32_t ptr_ty_id = ptr_inst->type_id();
52   Instruction* ptr_ty_inst = du_mgr->GetDef(ptr_ty_id);
53   if (spv::StorageClass(ptr_ty_inst->GetSingleWordInOperand(0)) !=
54       spv::StorageClass::PhysicalStorageBufferEXT)
55     return false;
56   return true;
57 }
58 
59 // TODO(greg-lunarg): Refactor with InstBindlessCheckPass::GenCheckCode() ??
GenCheckCode(uint32_t check_id,Instruction * ref_inst,std::vector<std::unique_ptr<BasicBlock>> * new_blocks)60 void InstBuffAddrCheckPass::GenCheckCode(
61     uint32_t check_id, Instruction* ref_inst,
62     std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
63   BasicBlock* back_blk_ptr = &*new_blocks->back();
64   InstructionBuilder builder(
65       context(), back_blk_ptr,
66       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
67   // Gen conditional branch on check_id. Valid branch generates original
68   // reference. Invalid generates debug output and zero result (if needed).
69   uint32_t merge_blk_id = TakeNextId();
70   uint32_t valid_blk_id = TakeNextId();
71   uint32_t invalid_blk_id = TakeNextId();
72   std::unique_ptr<Instruction> merge_label(NewLabel(merge_blk_id));
73   std::unique_ptr<Instruction> valid_label(NewLabel(valid_blk_id));
74   std::unique_ptr<Instruction> invalid_label(NewLabel(invalid_blk_id));
75   (void)builder.AddConditionalBranch(
76       check_id, valid_blk_id, invalid_blk_id, merge_blk_id,
77       uint32_t(spv::SelectionControlMask::MaskNone));
78   // Gen valid branch
79   std::unique_ptr<BasicBlock> new_blk_ptr(
80       new BasicBlock(std::move(valid_label)));
81   builder.SetInsertPoint(&*new_blk_ptr);
82   uint32_t new_ref_id = CloneOriginalReference(ref_inst, &builder);
83   (void)builder.AddBranch(merge_blk_id);
84   new_blocks->push_back(std::move(new_blk_ptr));
85   // Gen invalid block
86   new_blk_ptr.reset(new BasicBlock(std::move(invalid_label)));
87   builder.SetInsertPoint(&*new_blk_ptr);
88   // Gen zero for invalid load. If pointer type, need to convert uint64
89   // zero to pointer; cannot create ConstantNull of pointer type.
90   uint32_t null_id = 0;
91   if (new_ref_id != 0) {
92     uint32_t ref_type_id = ref_inst->type_id();
93     analysis::TypeManager* type_mgr = context()->get_type_mgr();
94     analysis::Type* ref_type = type_mgr->GetType(ref_type_id);
95     if (ref_type->AsPointer() != nullptr) {
96       uint32_t null_u64_id = GetNullId(GetUint64Id());
97       Instruction* null_ptr_inst = builder.AddUnaryOp(
98           ref_type_id, spv::Op::OpConvertUToPtr, null_u64_id);
99       null_id = null_ptr_inst->result_id();
100     } else {
101       null_id = GetNullId(ref_type_id);
102     }
103   }
104   (void)builder.AddBranch(merge_blk_id);
105   new_blocks->push_back(std::move(new_blk_ptr));
106   // Gen merge block
107   new_blk_ptr.reset(new BasicBlock(std::move(merge_label)));
108   builder.SetInsertPoint(&*new_blk_ptr);
109   // Gen phi of new reference and zero, if necessary, and replace the
110   // result id of the original reference with that of the Phi. Kill original
111   // reference.
112   if (new_ref_id != 0) {
113     Instruction* phi_inst =
114         builder.AddPhi(ref_inst->type_id(),
115                        {new_ref_id, valid_blk_id, null_id, invalid_blk_id});
116     context()->ReplaceAllUsesWith(ref_inst->result_id(), phi_inst->result_id());
117   }
118   new_blocks->push_back(std::move(new_blk_ptr));
119   context()->KillInst(ref_inst);
120 }
121 
GetTypeLength(uint32_t type_id)122 uint32_t InstBuffAddrCheckPass::GetTypeLength(uint32_t type_id) {
123   Instruction* type_inst = get_def_use_mgr()->GetDef(type_id);
124   switch (type_inst->opcode()) {
125     case spv::Op::OpTypeFloat:
126     case spv::Op::OpTypeInt:
127       return type_inst->GetSingleWordInOperand(0) / 8u;
128     case spv::Op::OpTypeVector:
129     case spv::Op::OpTypeMatrix:
130       return type_inst->GetSingleWordInOperand(1) *
131              GetTypeLength(type_inst->GetSingleWordInOperand(0));
132     case spv::Op::OpTypePointer:
133       assert(spv::StorageClass(type_inst->GetSingleWordInOperand(0)) ==
134                  spv::StorageClass::PhysicalStorageBufferEXT &&
135              "unexpected pointer type");
136       return 8u;
137     case spv::Op::OpTypeArray: {
138       uint32_t const_id = type_inst->GetSingleWordInOperand(1);
139       Instruction* const_inst = get_def_use_mgr()->GetDef(const_id);
140       uint32_t cnt = const_inst->GetSingleWordInOperand(0);
141       return cnt * GetTypeLength(type_inst->GetSingleWordInOperand(0));
142     }
143     case spv::Op::OpTypeStruct: {
144       // Figure out the location of the last byte of the last member of the
145       // structure.
146       uint32_t last_offset = 0, last_len = 0;
147 
148       get_decoration_mgr()->ForEachDecoration(
149           type_id, uint32_t(spv::Decoration::Offset),
150           [&last_offset](const Instruction& deco_inst) {
151             last_offset = deco_inst.GetSingleWordInOperand(3);
152           });
153       type_inst->ForEachInId([&last_len, this](const uint32_t* iid) {
154         last_len = GetTypeLength(*iid);
155       });
156       return last_offset + last_len;
157     }
158     case spv::Op::OpTypeRuntimeArray:
159     default:
160       assert(false && "unexpected type");
161       return 0;
162   }
163 }
164 
AddParam(uint32_t type_id,std::vector<uint32_t> * param_vec,std::unique_ptr<Function> * input_func)165 void InstBuffAddrCheckPass::AddParam(uint32_t type_id,
166                                      std::vector<uint32_t>* param_vec,
167                                      std::unique_ptr<Function>* input_func) {
168   uint32_t pid = TakeNextId();
169   param_vec->push_back(pid);
170   std::unique_ptr<Instruction> param_inst(new Instruction(
171       get_module()->context(), spv::Op::OpFunctionParameter, type_id, pid, {}));
172   get_def_use_mgr()->AnalyzeInstDefUse(&*param_inst);
173   (*input_func)->AddParameter(std::move(param_inst));
174 }
175 
176 // This is a stub function for use with Import linkage
177 // clang-format off
178 // GLSL:
179 //bool inst_bindless_search_and_test(const uint shader_id, const uint inst_num, const uvec4 stage_info,
180 //				     const uint64 ref_ptr, const uint length) {
181 //}
182 // clang-format on
GetSearchAndTestFuncId()183 uint32_t InstBuffAddrCheckPass::GetSearchAndTestFuncId() {
184   enum {
185     kShaderId = 0,
186     kInstructionIndex = 1,
187     kStageInfo = 2,
188     kRefPtr = 3,
189     kLength = 4,
190     kNumArgs
191   };
192   if (search_test_func_id_ != 0) {
193     return search_test_func_id_;
194   }
195   // Generate function "bool search_and_test(uint64_t ref_ptr, uint32_t len)"
196   // which searches input buffer for buffer which most likely contains the
197   // pointer value |ref_ptr| and verifies that the entire reference of
198   // length |len| bytes is contained in the buffer.
199   analysis::TypeManager* type_mgr = context()->get_type_mgr();
200   const analysis::Integer* uint_type = GetInteger(32, false);
201   const analysis::Vector v4uint(uint_type, 4);
202   const analysis::Type* v4uint_type = type_mgr->GetRegisteredType(&v4uint);
203 
204   std::vector<const analysis::Type*> param_types = {
205       uint_type, uint_type, v4uint_type, type_mgr->GetType(GetUint64Id()),
206       uint_type};
207 
208   const std::string func_name{"inst_buff_addr_search_and_test"};
209   const uint32_t func_id = TakeNextId();
210   std::unique_ptr<Function> func =
211       StartFunction(func_id, type_mgr->GetBoolType(), param_types);
212   func->SetFunctionEnd(EndFunction());
213   context()->AddFunctionDeclaration(std::move(func));
214   context()->AddDebug2Inst(NewName(func_id, func_name));
215 
216   std::vector<Operand> operands{
217       {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {func_id}},
218       {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
219        {uint32_t(spv::Decoration::LinkageAttributes)}},
220       {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_STRING,
221        utils::MakeVector(func_name.c_str())},
222       {spv_operand_type_t::SPV_OPERAND_TYPE_LINKAGE_TYPE,
223        {uint32_t(spv::LinkageType::Import)}},
224   };
225   get_decoration_mgr()->AddDecoration(spv::Op::OpDecorate, operands);
226 
227   search_test_func_id_ = func_id;
228   return search_test_func_id_;
229 }
230 
GenSearchAndTest(Instruction * ref_inst,InstructionBuilder * builder,uint32_t * ref_uptr_id,uint32_t stage_idx)231 uint32_t InstBuffAddrCheckPass::GenSearchAndTest(Instruction* ref_inst,
232                                                  InstructionBuilder* builder,
233                                                  uint32_t* ref_uptr_id,
234                                                  uint32_t stage_idx) {
235   // Enable Int64 if necessary
236   // Convert reference pointer to uint64
237   const uint32_t ref_ptr_id = ref_inst->GetSingleWordInOperand(0);
238   Instruction* ref_uptr_inst =
239       builder->AddUnaryOp(GetUint64Id(), spv::Op::OpConvertPtrToU, ref_ptr_id);
240   *ref_uptr_id = ref_uptr_inst->result_id();
241   // Compute reference length in bytes
242   analysis::DefUseManager* du_mgr = get_def_use_mgr();
243   Instruction* ref_ptr_inst = du_mgr->GetDef(ref_ptr_id);
244   const uint32_t ref_ptr_ty_id = ref_ptr_inst->type_id();
245   Instruction* ref_ptr_ty_inst = du_mgr->GetDef(ref_ptr_ty_id);
246   const uint32_t ref_len =
247       GetTypeLength(ref_ptr_ty_inst->GetSingleWordInOperand(1));
248   // Gen call to search and test function
249   const uint32_t func_id = GetSearchAndTestFuncId();
250   const std::vector<uint32_t> args = {
251       builder->GetUintConstantId(shader_id_),
252       builder->GetUintConstantId(ref_inst->unique_id()),
253       GenStageInfo(stage_idx, builder), *ref_uptr_id,
254       builder->GetUintConstantId(ref_len)};
255   return GenReadFunctionCall(GetBoolId(), func_id, args, builder);
256 }
257 
GenBuffAddrCheckCode(BasicBlock::iterator ref_inst_itr,UptrVectorIterator<BasicBlock> ref_block_itr,uint32_t stage_idx,std::vector<std::unique_ptr<BasicBlock>> * new_blocks)258 void InstBuffAddrCheckPass::GenBuffAddrCheckCode(
259     BasicBlock::iterator ref_inst_itr,
260     UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
261     std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
262   // Look for reference through indexed descriptor. If found, analyze and
263   // save components. If not, return.
264   Instruction* ref_inst = &*ref_inst_itr;
265   if (!IsPhysicalBuffAddrReference(ref_inst)) return;
266   // Move original block's preceding instructions into first new block
267   std::unique_ptr<BasicBlock> new_blk_ptr;
268   MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
269   InstructionBuilder builder(
270       context(), &*new_blk_ptr,
271       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
272   new_blocks->push_back(std::move(new_blk_ptr));
273   // Generate code to do search and test if all bytes of reference
274   // are within a listed buffer. Return reference pointer converted to uint64.
275   uint32_t ref_uptr_id;
276   uint32_t valid_id =
277       GenSearchAndTest(ref_inst, &builder, &ref_uptr_id, stage_idx);
278   // Generate test of search results with true branch
279   // being full reference and false branch being debug output and zero
280   // for the referenced value.
281   GenCheckCode(valid_id, ref_inst, new_blocks);
282 
283   // Move original block's remaining code into remainder/merge block and add
284   // to new blocks
285   BasicBlock* back_blk_ptr = &*new_blocks->back();
286   MovePostludeCode(ref_block_itr, back_blk_ptr);
287 }
288 
InitInstBuffAddrCheck()289 void InstBuffAddrCheckPass::InitInstBuffAddrCheck() {
290   // Initialize base class
291   InitializeInstrument();
292   // Initialize class
293   search_test_func_id_ = 0;
294 }
295 
ProcessImpl()296 Pass::Status InstBuffAddrCheckPass::ProcessImpl() {
297   // The memory model and linkage must always be updated for spirv-link to work
298   // correctly.
299   AddStorageBufferExt();
300   if (!get_feature_mgr()->HasExtension(kSPV_KHR_physical_storage_buffer)) {
301     context()->AddExtension("SPV_KHR_physical_storage_buffer");
302   }
303 
304   context()->AddCapability(spv::Capability::PhysicalStorageBufferAddresses);
305   Instruction* memory_model = get_module()->GetMemoryModel();
306   memory_model->SetInOperand(
307       0u, {uint32_t(spv::AddressingModel::PhysicalStorageBuffer64)});
308 
309   context()->AddCapability(spv::Capability::Int64);
310   context()->AddCapability(spv::Capability::Linkage);
311   // Perform bindless bounds check on each entry point function in module
312   InstProcessFunction pfn =
313       [this](BasicBlock::iterator ref_inst_itr,
314              UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
315              std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
316         return GenBuffAddrCheckCode(ref_inst_itr, ref_block_itr, stage_idx,
317                                     new_blocks);
318       };
319   InstProcessEntryPointCallTree(pfn);
320   // This pass always changes the memory model, so that linking will work
321   // properly.
322   return Status::SuccessWithChange;
323 }
324 
Process()325 Pass::Status InstBuffAddrCheckPass::Process() {
326   InitInstBuffAddrCheck();
327   return ProcessImpl();
328 }
329 
330 }  // namespace opt
331 }  // namespace spvtools
332