1 // Copyright (c) 2018 The Khronos Group Inc.
2 // Copyright (c) 2018 Valve Corporation
3 // Copyright (c) 2018 LunarG Inc.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 
17 #include "instrument_pass.h"
18 
19 #include "source/cfa.h"
20 #include "source/spirv_constant.h"
21 
22 namespace spvtools {
23 namespace opt {
namespace {
// In-operand index of the entry-point function <id> within OpEntryPoint
// (in-operand 0 is the execution model, in-operand 1 is the function).
constexpr uint32_t kEntryPointFunctionIdInIdx = 1u;
}  // namespace
28 
MovePreludeCode(BasicBlock::iterator ref_inst_itr,UptrVectorIterator<BasicBlock> ref_block_itr,std::unique_ptr<BasicBlock> * new_blk_ptr)29 void InstrumentPass::MovePreludeCode(
30     BasicBlock::iterator ref_inst_itr,
31     UptrVectorIterator<BasicBlock> ref_block_itr,
32     std::unique_ptr<BasicBlock>* new_blk_ptr) {
33   same_block_pre_.clear();
34   same_block_post_.clear();
35   // Initialize new block. Reuse label from original block.
36   new_blk_ptr->reset(new BasicBlock(std::move(ref_block_itr->GetLabel())));
37   // Move contents of original ref block up to ref instruction.
38   for (auto cii = ref_block_itr->begin(); cii != ref_inst_itr;
39        cii = ref_block_itr->begin()) {
40     Instruction* inst = &*cii;
41     inst->RemoveFromList();
42     std::unique_ptr<Instruction> mv_ptr(inst);
43     // Remember same-block ops for possible regeneration.
44     if (IsSameBlockOp(&*mv_ptr)) {
45       auto* sb_inst_ptr = mv_ptr.get();
46       same_block_pre_[mv_ptr->result_id()] = sb_inst_ptr;
47     }
48     (*new_blk_ptr)->AddInstruction(std::move(mv_ptr));
49   }
50 }
51 
MovePostludeCode(UptrVectorIterator<BasicBlock> ref_block_itr,BasicBlock * new_blk_ptr)52 void InstrumentPass::MovePostludeCode(
53     UptrVectorIterator<BasicBlock> ref_block_itr, BasicBlock* new_blk_ptr) {
54   // Move contents of original ref block.
55   for (auto cii = ref_block_itr->begin(); cii != ref_block_itr->end();
56        cii = ref_block_itr->begin()) {
57     Instruction* inst = &*cii;
58     inst->RemoveFromList();
59     std::unique_ptr<Instruction> mv_inst(inst);
60     // Regenerate any same-block instruction that has not been seen in the
61     // current block.
62     if (same_block_pre_.size() > 0) {
63       CloneSameBlockOps(&mv_inst, &same_block_post_, &same_block_pre_,
64                         new_blk_ptr);
65       // Remember same-block ops in this block.
66       if (IsSameBlockOp(&*mv_inst)) {
67         const uint32_t rid = mv_inst->result_id();
68         same_block_post_[rid] = rid;
69       }
70     }
71     new_blk_ptr->AddInstruction(std::move(mv_inst));
72   }
73 }
74 
NewLabel(uint32_t label_id)75 std::unique_ptr<Instruction> InstrumentPass::NewLabel(uint32_t label_id) {
76   auto new_label =
77       MakeUnique<Instruction>(context(), spv::Op::OpLabel, 0, label_id,
78                               std::initializer_list<Operand>{});
79   get_def_use_mgr()->AnalyzeInstDefUse(&*new_label);
80   return new_label;
81 }
82 
StartFunction(uint32_t func_id,const analysis::Type * return_type,const std::vector<const analysis::Type * > & param_types)83 std::unique_ptr<Function> InstrumentPass::StartFunction(
84     uint32_t func_id, const analysis::Type* return_type,
85     const std::vector<const analysis::Type*>& param_types) {
86   analysis::TypeManager* type_mgr = context()->get_type_mgr();
87   analysis::Function* func_type = GetFunction(return_type, param_types);
88 
89   const std::vector<Operand> operands{
90       {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
91        {uint32_t(spv::FunctionControlMask::MaskNone)}},
92       {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {type_mgr->GetId(func_type)}},
93   };
94   auto func_inst =
95       MakeUnique<Instruction>(context(), spv::Op::OpFunction,
96                               type_mgr->GetId(return_type), func_id, operands);
97   get_def_use_mgr()->AnalyzeInstDefUse(&*func_inst);
98   return MakeUnique<Function>(std::move(func_inst));
99 }
100 
EndFunction()101 std::unique_ptr<Instruction> InstrumentPass::EndFunction() {
102   auto end = MakeUnique<Instruction>(context(), spv::Op::OpFunctionEnd, 0, 0,
103                                      std::initializer_list<Operand>{});
104   get_def_use_mgr()->AnalyzeInstDefUse(end.get());
105   return end;
106 }
107 
AddParameters(Function & func,const std::vector<const analysis::Type * > & param_types)108 std::vector<uint32_t> InstrumentPass::AddParameters(
109     Function& func, const std::vector<const analysis::Type*>& param_types) {
110   std::vector<uint32_t> param_ids;
111   param_ids.reserve(param_types.size());
112   for (const analysis::Type* param : param_types) {
113     uint32_t pid = TakeNextId();
114     param_ids.push_back(pid);
115     auto param_inst =
116         MakeUnique<Instruction>(context(), spv::Op::OpFunctionParameter,
117                                 context()->get_type_mgr()->GetId(param), pid,
118                                 std::initializer_list<Operand>{});
119     get_def_use_mgr()->AnalyzeInstDefUse(param_inst.get());
120     func.AddParameter(std::move(param_inst));
121   }
122   return param_ids;
123 }
124 
NewName(uint32_t id,const std::string & name_str)125 std::unique_ptr<Instruction> InstrumentPass::NewName(
126     uint32_t id, const std::string& name_str) {
127   return MakeUnique<Instruction>(
128       context(), spv::Op::OpName, 0, 0,
129       std::initializer_list<Operand>{
130           {SPV_OPERAND_TYPE_ID, {id}},
131           {SPV_OPERAND_TYPE_LITERAL_STRING, utils::MakeVector(name_str)}});
132 }
133 
Gen32BitCvtCode(uint32_t val_id,InstructionBuilder * builder)134 uint32_t InstrumentPass::Gen32BitCvtCode(uint32_t val_id,
135                                          InstructionBuilder* builder) {
136   // Convert integer value to 32-bit if necessary
137   analysis::TypeManager* type_mgr = context()->get_type_mgr();
138   uint32_t val_ty_id = get_def_use_mgr()->GetDef(val_id)->type_id();
139   analysis::Integer* val_ty = type_mgr->GetType(val_ty_id)->AsInteger();
140   if (val_ty->width() == 32) return val_id;
141   bool is_signed = val_ty->IsSigned();
142   analysis::Integer val_32b_ty(32, is_signed);
143   analysis::Type* val_32b_reg_ty = type_mgr->GetRegisteredType(&val_32b_ty);
144   uint32_t val_32b_reg_ty_id = type_mgr->GetId(val_32b_reg_ty);
145   if (is_signed)
146     return builder->AddUnaryOp(val_32b_reg_ty_id, spv::Op::OpSConvert, val_id)
147         ->result_id();
148   else
149     return builder->AddUnaryOp(val_32b_reg_ty_id, spv::Op::OpUConvert, val_id)
150         ->result_id();
151 }
152 
GenUintCastCode(uint32_t val_id,InstructionBuilder * builder)153 uint32_t InstrumentPass::GenUintCastCode(uint32_t val_id,
154                                          InstructionBuilder* builder) {
155   // Convert value to 32-bit if necessary
156   uint32_t val_32b_id = Gen32BitCvtCode(val_id, builder);
157   // Cast value to unsigned if necessary
158   analysis::TypeManager* type_mgr = context()->get_type_mgr();
159   uint32_t val_ty_id = get_def_use_mgr()->GetDef(val_32b_id)->type_id();
160   analysis::Integer* val_ty = type_mgr->GetType(val_ty_id)->AsInteger();
161   if (!val_ty->IsSigned()) return val_32b_id;
162   return builder->AddUnaryOp(GetUintId(), spv::Op::OpBitcast, val_32b_id)
163       ->result_id();
164 }
165 
GenVarLoad(uint32_t var_id,InstructionBuilder * builder)166 uint32_t InstrumentPass::GenVarLoad(uint32_t var_id,
167                                     InstructionBuilder* builder) {
168   Instruction* var_inst = get_def_use_mgr()->GetDef(var_id);
169   uint32_t type_id = GetPointeeTypeId(var_inst);
170   Instruction* load_inst = builder->AddLoad(type_id, var_id);
171   return load_inst->result_id();
172 }
173 
GenStageInfo(uint32_t stage_idx,InstructionBuilder * builder)174 uint32_t InstrumentPass::GenStageInfo(uint32_t stage_idx,
175                                       InstructionBuilder* builder) {
176   std::vector<uint32_t> ids(4, builder->GetUintConstantId(0));
177   ids[0] = builder->GetUintConstantId(stage_idx);
178   // %289 = OpCompositeConstruct %v4uint %uint_0 %285 %288 %uint_0
179   // TODO(greg-lunarg): Add support for all stages
180   switch (spv::ExecutionModel(stage_idx)) {
181     case spv::ExecutionModel::Vertex: {
182       // Load and store VertexId and InstanceId
183       uint32_t load_id = GenVarLoad(
184           context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::VertexIndex)),
185           builder);
186       ids[1] = GenUintCastCode(load_id, builder);
187 
188       load_id = GenVarLoad(context()->GetBuiltinInputVarId(
189                                uint32_t(spv::BuiltIn::InstanceIndex)),
190                            builder);
191       ids[2] = GenUintCastCode(load_id, builder);
192     } break;
193     case spv::ExecutionModel::GLCompute:
194     case spv::ExecutionModel::TaskNV:
195     case spv::ExecutionModel::MeshNV:
196     case spv::ExecutionModel::TaskEXT:
197     case spv::ExecutionModel::MeshEXT: {
198       // Load and store GlobalInvocationId.
199       uint32_t load_id = GenVarLoad(context()->GetBuiltinInputVarId(uint32_t(
200                                         spv::BuiltIn::GlobalInvocationId)),
201                                     builder);
202       for (uint32_t u = 0; u < 3u; ++u) {
203         ids[u + 1] = builder->AddCompositeExtract(GetUintId(), load_id, {u})
204                          ->result_id();
205       }
206     } break;
207     case spv::ExecutionModel::Geometry: {
208       // Load and store PrimitiveId and InvocationId.
209       uint32_t load_id = GenVarLoad(
210           context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::PrimitiveId)),
211           builder);
212       ids[1] = load_id;
213       load_id = GenVarLoad(
214           context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::InvocationId)),
215           builder);
216       ids[2] = GenUintCastCode(load_id, builder);
217     } break;
218     case spv::ExecutionModel::TessellationControl: {
219       // Load and store InvocationId and PrimitiveId
220       uint32_t load_id = GenVarLoad(
221           context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::InvocationId)),
222           builder);
223       ids[1] = GenUintCastCode(load_id, builder);
224       load_id = GenVarLoad(
225           context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::PrimitiveId)),
226           builder);
227       ids[2] = load_id;
228     } break;
229     case spv::ExecutionModel::TessellationEvaluation: {
230       // Load and store PrimitiveId and TessCoord.uv
231       uint32_t load_id = GenVarLoad(
232           context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::PrimitiveId)),
233           builder);
234       ids[1] = load_id;
235       load_id = GenVarLoad(
236           context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::TessCoord)),
237           builder);
238       Instruction* uvec3_cast_inst =
239           builder->AddUnaryOp(GetVec3UintId(), spv::Op::OpBitcast, load_id);
240       uint32_t uvec3_cast_id = uvec3_cast_inst->result_id();
241       for (uint32_t u = 0; u < 2u; ++u) {
242         ids[u + 2] =
243             builder->AddCompositeExtract(GetUintId(), uvec3_cast_id, {u})
244                 ->result_id();
245       }
246     } break;
247     case spv::ExecutionModel::Fragment: {
248       // Load FragCoord and convert to Uint
249       Instruction* frag_coord_inst = builder->AddLoad(
250           GetVec4FloatId(),
251           context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::FragCoord)));
252       Instruction* uint_frag_coord_inst = builder->AddUnaryOp(
253           GetVec4UintId(), spv::Op::OpBitcast, frag_coord_inst->result_id());
254       for (uint32_t u = 0; u < 2u; ++u) {
255         ids[u + 1] =
256             builder
257                 ->AddCompositeExtract(GetUintId(),
258                                       uint_frag_coord_inst->result_id(), {u})
259                 ->result_id();
260       }
261     } break;
262     case spv::ExecutionModel::RayGenerationNV:
263     case spv::ExecutionModel::IntersectionNV:
264     case spv::ExecutionModel::AnyHitNV:
265     case spv::ExecutionModel::ClosestHitNV:
266     case spv::ExecutionModel::MissNV:
267     case spv::ExecutionModel::CallableNV: {
268       // Load and store LaunchIdNV.
269       uint32_t launch_id = GenVarLoad(
270           context()->GetBuiltinInputVarId(uint32_t(spv::BuiltIn::LaunchIdNV)),
271           builder);
272       for (uint32_t u = 0; u < 3u; ++u) {
273         ids[u + 1] = builder->AddCompositeExtract(GetUintId(), launch_id, {u})
274                          ->result_id();
275       }
276     } break;
277     default: { assert(false && "unsupported stage"); } break;
278   }
279   return builder->AddCompositeConstruct(GetVec4UintId(), ids)->result_id();
280 }
281 
AllConstant(const std::vector<uint32_t> & ids)282 bool InstrumentPass::AllConstant(const std::vector<uint32_t>& ids) {
283   for (auto& id : ids) {
284     Instruction* id_inst = context()->get_def_use_mgr()->GetDef(id);
285     if (!spvOpcodeIsConstant(id_inst->opcode())) return false;
286   }
287   return true;
288 }
289 
GenReadFunctionCall(uint32_t return_id,uint32_t func_id,const std::vector<uint32_t> & func_call_args,InstructionBuilder * ref_builder)290 uint32_t InstrumentPass::GenReadFunctionCall(
291     uint32_t return_id, uint32_t func_id,
292     const std::vector<uint32_t>& func_call_args,
293     InstructionBuilder* ref_builder) {
294   // If optimizing direct reads and the call has already been generated,
295   // use its result
296   if (opt_direct_reads_) {
297     uint32_t res_id = call2id_[func_call_args];
298     if (res_id != 0) return res_id;
299   }
300   // If the function arguments are all constants, the call can be moved to the
301   // first block of the function where its result can be reused. One example
302   // where this is profitable is for uniform buffer references, of which there
303   // are often many.
304   InstructionBuilder builder(ref_builder->GetContext(),
305                              &*ref_builder->GetInsertPoint(),
306                              ref_builder->GetPreservedAnalysis());
307   bool insert_in_first_block = opt_direct_reads_ && AllConstant(func_call_args);
308   if (insert_in_first_block) {
309     Instruction* insert_before = &*curr_func_->begin()->tail();
310     builder.SetInsertPoint(insert_before);
311   }
312   uint32_t res_id =
313       builder.AddFunctionCall(return_id, func_id, func_call_args)->result_id();
314   if (insert_in_first_block) call2id_[func_call_args] = res_id;
315   return res_id;
316 }
317 
IsSameBlockOp(const Instruction * inst) const318 bool InstrumentPass::IsSameBlockOp(const Instruction* inst) const {
319   return inst->opcode() == spv::Op::OpSampledImage ||
320          inst->opcode() == spv::Op::OpImage;
321 }
322 
CloneSameBlockOps(std::unique_ptr<Instruction> * inst,std::unordered_map<uint32_t,uint32_t> * same_blk_post,std::unordered_map<uint32_t,Instruction * > * same_blk_pre,BasicBlock * block_ptr)323 void InstrumentPass::CloneSameBlockOps(
324     std::unique_ptr<Instruction>* inst,
325     std::unordered_map<uint32_t, uint32_t>* same_blk_post,
326     std::unordered_map<uint32_t, Instruction*>* same_blk_pre,
327     BasicBlock* block_ptr) {
328   bool changed = false;
329   (*inst)->ForEachInId([&same_blk_post, &same_blk_pre, &block_ptr, &changed,
330                         this](uint32_t* iid) {
331     const auto map_itr = (*same_blk_post).find(*iid);
332     if (map_itr == (*same_blk_post).end()) {
333       const auto map_itr2 = (*same_blk_pre).find(*iid);
334       if (map_itr2 != (*same_blk_pre).end()) {
335         // Clone pre-call same-block ops, map result id.
336         const Instruction* in_inst = map_itr2->second;
337         std::unique_ptr<Instruction> sb_inst(in_inst->Clone(context()));
338         const uint32_t rid = sb_inst->result_id();
339         const uint32_t nid = this->TakeNextId();
340         get_decoration_mgr()->CloneDecorations(rid, nid);
341         sb_inst->SetResultId(nid);
342         get_def_use_mgr()->AnalyzeInstDefUse(&*sb_inst);
343         (*same_blk_post)[rid] = nid;
344         *iid = nid;
345         changed = true;
346         CloneSameBlockOps(&sb_inst, same_blk_post, same_blk_pre, block_ptr);
347         block_ptr->AddInstruction(std::move(sb_inst));
348       }
349     } else {
350       // Reset same-block op operand if necessary
351       if (*iid != map_itr->second) {
352         *iid = map_itr->second;
353         changed = true;
354       }
355     }
356   });
357   if (changed) get_def_use_mgr()->AnalyzeInstUse(&**inst);
358 }
359 
UpdateSucceedingPhis(std::vector<std::unique_ptr<BasicBlock>> & new_blocks)360 void InstrumentPass::UpdateSucceedingPhis(
361     std::vector<std::unique_ptr<BasicBlock>>& new_blocks) {
362   const auto first_blk = new_blocks.begin();
363   const auto last_blk = new_blocks.end() - 1;
364   const uint32_t first_id = (*first_blk)->id();
365   const uint32_t last_id = (*last_blk)->id();
366   const BasicBlock& const_last_block = *last_blk->get();
367   const_last_block.ForEachSuccessorLabel(
368       [&first_id, &last_id, this](const uint32_t succ) {
369         BasicBlock* sbp = this->id2block_[succ];
370         sbp->ForEachPhiInst([&first_id, &last_id, this](Instruction* phi) {
371           bool changed = false;
372           phi->ForEachInId([&first_id, &last_id, &changed](uint32_t* id) {
373             if (*id == first_id) {
374               *id = last_id;
375               changed = true;
376             }
377           });
378           if (changed) get_def_use_mgr()->AnalyzeInstUse(phi);
379         });
380       });
381 }
382 
GetInteger(uint32_t width,bool is_signed)383 analysis::Integer* InstrumentPass::GetInteger(uint32_t width, bool is_signed) {
384   analysis::Integer i(width, is_signed);
385   analysis::Type* type = context()->get_type_mgr()->GetRegisteredType(&i);
386   assert(type && type->AsInteger());
387   return type->AsInteger();
388 }
389 
GetStruct(const std::vector<const analysis::Type * > & fields)390 analysis::Struct* InstrumentPass::GetStruct(
391     const std::vector<const analysis::Type*>& fields) {
392   analysis::Struct s(fields);
393   analysis::Type* type = context()->get_type_mgr()->GetRegisteredType(&s);
394   assert(type && type->AsStruct());
395   return type->AsStruct();
396 }
397 
// Returns the registered runtime array type of `element`, creating it if
// needed.
analysis::RuntimeArray* InstrumentPass::GetRuntimeArray(
    const analysis::Type* element) {
  analysis::RuntimeArray wanted(element);
  analysis::Type* registered =
      context()->get_type_mgr()->GetRegisteredType(&wanted);
  assert(registered != nullptr && registered->AsRuntimeArray() != nullptr);
  return registered->AsRuntimeArray();
}
405 
// Returns the registered array type of `length` elements of `element`. The
// length is expressed as a uint constant id plus its literal value. Creates
// the type if needed.
analysis::Array* InstrumentPass::GetArray(const analysis::Type* element,
                                          uint32_t length) {
  const uint32_t len_const_id =
      context()->get_constant_mgr()->GetUIntConstId(length);
  analysis::Array::LengthInfo len_info{
      len_const_id, {analysis::Array::LengthInfo::Case::kConstant, length}};
  analysis::Array wanted(element, len_info);
  analysis::Type* registered =
      context()->get_type_mgr()->GetRegisteredType(&wanted);
  assert(registered != nullptr && registered->AsArray() != nullptr);
  return registered->AsArray();
}
418 
// Returns the registered function type `return_val(args...)`, creating it if
// needed.
analysis::Function* InstrumentPass::GetFunction(
    const analysis::Type* return_val,
    const std::vector<const analysis::Type*>& args) {
  analysis::Function wanted(return_val, args);
  analysis::Type* registered =
      context()->get_type_mgr()->GetRegisteredType(&wanted);
  assert(registered != nullptr && registered->AsFunction() != nullptr);
  return registered->AsFunction();
}
427 
GetUintXRuntimeArrayType(uint32_t width,analysis::RuntimeArray ** rarr_ty)428 analysis::RuntimeArray* InstrumentPass::GetUintXRuntimeArrayType(
429     uint32_t width, analysis::RuntimeArray** rarr_ty) {
430   if (*rarr_ty == nullptr) {
431     *rarr_ty = GetRuntimeArray(GetInteger(width, false));
432     uint32_t uint_arr_ty_id =
433         context()->get_type_mgr()->GetTypeInstruction(*rarr_ty);
434     // By the Vulkan spec, a pre-existing RuntimeArray of uint must be part of
435     // a block, and will therefore be decorated with an ArrayStride. Therefore
436     // the undecorated type returned here will not be pre-existing and can
437     // safely be decorated. Since this type is now decorated, it is out of
438     // sync with the TypeManager and therefore the TypeManager must be
439     // invalidated after this pass.
440     assert(get_def_use_mgr()->NumUses(uint_arr_ty_id) == 0 &&
441            "used RuntimeArray type returned");
442     get_decoration_mgr()->AddDecorationVal(
443         uint_arr_ty_id, uint32_t(spv::Decoration::ArrayStride), width / 8u);
444   }
445   return *rarr_ty;
446 }
447 
GetUintRuntimeArrayType(uint32_t width)448 analysis::RuntimeArray* InstrumentPass::GetUintRuntimeArrayType(
449     uint32_t width) {
450   analysis::RuntimeArray** rarr_ty =
451       (width == 64) ? &uint64_rarr_ty_ : &uint32_rarr_ty_;
452   return GetUintXRuntimeArrayType(width, rarr_ty);
453 }
454 
AddStorageBufferExt()455 void InstrumentPass::AddStorageBufferExt() {
456   if (storage_buffer_ext_defined_) return;
457   if (!get_feature_mgr()->HasExtension(kSPV_KHR_storage_buffer_storage_class)) {
458     context()->AddExtension("SPV_KHR_storage_buffer_storage_class");
459   }
460   storage_buffer_ext_defined_ = true;
461 }
462 
GetFloatId()463 uint32_t InstrumentPass::GetFloatId() {
464   if (float_id_ == 0) {
465     analysis::TypeManager* type_mgr = context()->get_type_mgr();
466     analysis::Float float_ty(32);
467     analysis::Type* reg_float_ty = type_mgr->GetRegisteredType(&float_ty);
468     float_id_ = type_mgr->GetTypeInstruction(reg_float_ty);
469   }
470   return float_id_;
471 }
472 
GetVec4FloatId()473 uint32_t InstrumentPass::GetVec4FloatId() {
474   if (v4float_id_ == 0) {
475     analysis::TypeManager* type_mgr = context()->get_type_mgr();
476     analysis::Float float_ty(32);
477     analysis::Type* reg_float_ty = type_mgr->GetRegisteredType(&float_ty);
478     analysis::Vector v4float_ty(reg_float_ty, 4);
479     analysis::Type* reg_v4float_ty = type_mgr->GetRegisteredType(&v4float_ty);
480     v4float_id_ = type_mgr->GetTypeInstruction(reg_v4float_ty);
481   }
482   return v4float_id_;
483 }
484 
GetUintId()485 uint32_t InstrumentPass::GetUintId() {
486   if (uint_id_ == 0) {
487     analysis::TypeManager* type_mgr = context()->get_type_mgr();
488     analysis::Integer uint_ty(32, false);
489     analysis::Type* reg_uint_ty = type_mgr->GetRegisteredType(&uint_ty);
490     uint_id_ = type_mgr->GetTypeInstruction(reg_uint_ty);
491   }
492   return uint_id_;
493 }
494 
GetUint64Id()495 uint32_t InstrumentPass::GetUint64Id() {
496   if (uint64_id_ == 0) {
497     analysis::TypeManager* type_mgr = context()->get_type_mgr();
498     analysis::Integer uint64_ty(64, false);
499     analysis::Type* reg_uint64_ty = type_mgr->GetRegisteredType(&uint64_ty);
500     uint64_id_ = type_mgr->GetTypeInstruction(reg_uint64_ty);
501   }
502   return uint64_id_;
503 }
504 
GetUint8Id()505 uint32_t InstrumentPass::GetUint8Id() {
506   if (uint8_id_ == 0) {
507     analysis::TypeManager* type_mgr = context()->get_type_mgr();
508     analysis::Integer uint8_ty(8, false);
509     analysis::Type* reg_uint8_ty = type_mgr->GetRegisteredType(&uint8_ty);
510     uint8_id_ = type_mgr->GetTypeInstruction(reg_uint8_ty);
511   }
512   return uint8_id_;
513 }
514 
GetVecUintId(uint32_t len)515 uint32_t InstrumentPass::GetVecUintId(uint32_t len) {
516   analysis::TypeManager* type_mgr = context()->get_type_mgr();
517   analysis::Integer uint_ty(32, false);
518   analysis::Type* reg_uint_ty = type_mgr->GetRegisteredType(&uint_ty);
519   analysis::Vector v_uint_ty(reg_uint_ty, len);
520   analysis::Type* reg_v_uint_ty = type_mgr->GetRegisteredType(&v_uint_ty);
521   uint32_t v_uint_id = type_mgr->GetTypeInstruction(reg_v_uint_ty);
522   return v_uint_id;
523 }
524 
GetVec4UintId()525 uint32_t InstrumentPass::GetVec4UintId() {
526   if (v4uint_id_ == 0) v4uint_id_ = GetVecUintId(4u);
527   return v4uint_id_;
528 }
529 
GetVec3UintId()530 uint32_t InstrumentPass::GetVec3UintId() {
531   if (v3uint_id_ == 0) v3uint_id_ = GetVecUintId(3u);
532   return v3uint_id_;
533 }
534 
GetBoolId()535 uint32_t InstrumentPass::GetBoolId() {
536   if (bool_id_ == 0) {
537     analysis::TypeManager* type_mgr = context()->get_type_mgr();
538     analysis::Bool bool_ty;
539     analysis::Type* reg_bool_ty = type_mgr->GetRegisteredType(&bool_ty);
540     bool_id_ = type_mgr->GetTypeInstruction(reg_bool_ty);
541   }
542   return bool_id_;
543 }
544 
GetVoidId()545 uint32_t InstrumentPass::GetVoidId() {
546   if (void_id_ == 0) {
547     analysis::TypeManager* type_mgr = context()->get_type_mgr();
548     analysis::Void void_ty;
549     analysis::Type* reg_void_ty = type_mgr->GetRegisteredType(&void_ty);
550     void_id_ = type_mgr->GetTypeInstruction(reg_void_ty);
551   }
552   return void_id_;
553 }
554 
SplitBlock(BasicBlock::iterator inst_itr,UptrVectorIterator<BasicBlock> block_itr,std::vector<std::unique_ptr<BasicBlock>> * new_blocks)555 void InstrumentPass::SplitBlock(
556     BasicBlock::iterator inst_itr, UptrVectorIterator<BasicBlock> block_itr,
557     std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
558   // Make sure def/use analysis is done before we start moving instructions
559   // out of function
560   (void)get_def_use_mgr();
561   // Move original block's preceding instructions into first new block
562   std::unique_ptr<BasicBlock> first_blk_ptr;
563   MovePreludeCode(inst_itr, block_itr, &first_blk_ptr);
564   InstructionBuilder builder(
565       context(), &*first_blk_ptr,
566       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
567   uint32_t split_blk_id = TakeNextId();
568   std::unique_ptr<Instruction> split_label(NewLabel(split_blk_id));
569   (void)builder.AddBranch(split_blk_id);
570   new_blocks->push_back(std::move(first_blk_ptr));
571   // Move remaining instructions into split block and add to new blocks
572   std::unique_ptr<BasicBlock> split_blk_ptr(
573       new BasicBlock(std::move(split_label)));
574   MovePostludeCode(block_itr, &*split_blk_ptr);
575   new_blocks->push_back(std::move(split_blk_ptr));
576 }
577 
InstrumentFunction(Function * func,uint32_t stage_idx,InstProcessFunction & pfn)578 bool InstrumentPass::InstrumentFunction(Function* func, uint32_t stage_idx,
579                                         InstProcessFunction& pfn) {
580   curr_func_ = func;
581   call2id_.clear();
582   bool first_block_split = false;
583   bool modified = false;
584   // Apply instrumentation function to each instruction.
585   // Using block iterators here because of block erasures and insertions.
586   std::vector<std::unique_ptr<BasicBlock>> new_blks;
587   for (auto bi = func->begin(); bi != func->end(); ++bi) {
588     for (auto ii = bi->begin(); ii != bi->end();) {
589       // Split all executable instructions out of first block into a following
590       // block. This will allow function calls to be inserted into the first
591       // block without interfering with the instrumentation algorithm.
592       if (opt_direct_reads_ && !first_block_split) {
593         if (ii->opcode() != spv::Op::OpVariable) {
594           SplitBlock(ii, bi, &new_blks);
595           first_block_split = true;
596         }
597       } else {
598         pfn(ii, bi, stage_idx, &new_blks);
599       }
600       // If no new code, continue
601       if (new_blks.size() == 0) {
602         ++ii;
603         continue;
604       }
605       // Add new blocks to label id map
606       for (auto& blk : new_blks) id2block_[blk->id()] = &*blk;
607       // If there are new blocks we know there will always be two or
608       // more, so update succeeding phis with label of new last block.
609       size_t newBlocksSize = new_blks.size();
610       assert(newBlocksSize > 1);
611       UpdateSucceedingPhis(new_blks);
612       // Replace original block with new block(s)
613       bi = bi.Erase();
614       for (auto& bb : new_blks) {
615         bb->SetParent(func);
616       }
617       bi = bi.InsertBefore(&new_blks);
618       // Reset block iterator to last new block
619       for (size_t i = 0; i < newBlocksSize - 1; i++) ++bi;
620       modified = true;
621       // Restart instrumenting at beginning of last new block,
622       // but skip over any new phi or copy instruction.
623       ii = bi->begin();
624       if (ii->opcode() == spv::Op::OpPhi ||
625           ii->opcode() == spv::Op::OpCopyObject)
626         ++ii;
627       new_blks.clear();
628     }
629   }
630   return modified;
631 }
632 
InstProcessCallTreeFromRoots(InstProcessFunction & pfn,std::queue<uint32_t> * roots,uint32_t stage_idx)633 bool InstrumentPass::InstProcessCallTreeFromRoots(InstProcessFunction& pfn,
634                                                   std::queue<uint32_t>* roots,
635                                                   uint32_t stage_idx) {
636   bool modified = false;
637   std::unordered_set<uint32_t> done;
638   // Don't process input and output functions
639   for (auto& ifn : param2input_func_id_) done.insert(ifn.second);
640   for (auto& ofn : param2output_func_id_) done.insert(ofn.second);
641   // Process all functions from roots
642   while (!roots->empty()) {
643     const uint32_t fi = roots->front();
644     roots->pop();
645     if (done.insert(fi).second) {
646       Function* fn = id2function_.at(fi);
647       // Add calls first so we don't add new output function
648       context()->AddCalls(fn, roots);
649       modified = InstrumentFunction(fn, stage_idx, pfn) || modified;
650     }
651   }
652   return modified;
653 }
654 
InstProcessEntryPointCallTree(InstProcessFunction & pfn)655 bool InstrumentPass::InstProcessEntryPointCallTree(InstProcessFunction& pfn) {
656   uint32_t stage_id;
657   if (use_stage_info_) {
658     // Make sure all entry points have the same execution model. Do not
659     // instrument if they do not.
660     // TODO(greg-lunarg): Handle mixed stages. Technically, a shader module
661     // can contain entry points with different execution models, although
662     // such modules will likely be rare as GLSL and HLSL are geared toward
663     // one model per module. In such cases we will need
664     // to clone any functions which are in the call trees of entrypoints
665     // with differing execution models.
666     spv::ExecutionModel stage = context()->GetStage();
667     // Check for supported stages
668     if (stage != spv::ExecutionModel::Vertex &&
669         stage != spv::ExecutionModel::Fragment &&
670         stage != spv::ExecutionModel::Geometry &&
671         stage != spv::ExecutionModel::GLCompute &&
672         stage != spv::ExecutionModel::TessellationControl &&
673         stage != spv::ExecutionModel::TessellationEvaluation &&
674         stage != spv::ExecutionModel::TaskNV &&
675         stage != spv::ExecutionModel::MeshNV &&
676         stage != spv::ExecutionModel::RayGenerationNV &&
677         stage != spv::ExecutionModel::IntersectionNV &&
678         stage != spv::ExecutionModel::AnyHitNV &&
679         stage != spv::ExecutionModel::ClosestHitNV &&
680         stage != spv::ExecutionModel::MissNV &&
681         stage != spv::ExecutionModel::CallableNV &&
682         stage != spv::ExecutionModel::TaskEXT &&
683         stage != spv::ExecutionModel::MeshEXT) {
684       if (consumer()) {
685         std::string message = "Stage not supported by instrumentation";
686         consumer()(SPV_MSG_ERROR, 0, {0, 0, 0}, message.c_str());
687       }
688       return false;
689     }
690     stage_id = static_cast<uint32_t>(stage);
691   } else {
692     stage_id = 0;
693   }
694   // Add together the roots of all entry points
695   std::queue<uint32_t> roots;
696   for (auto& e : get_module()->entry_points()) {
697     roots.push(e.GetSingleWordInOperand(kEntryPointFunctionIdInIdx));
698   }
699   bool modified = InstProcessCallTreeFromRoots(pfn, &roots, stage_id);
700   return modified;
701 }
702 
InitializeInstrument()703 void InstrumentPass::InitializeInstrument() {
704   float_id_ = 0;
705   v4float_id_ = 0;
706   uint_id_ = 0;
707   uint64_id_ = 0;
708   uint8_id_ = 0;
709   v4uint_id_ = 0;
710   v3uint_id_ = 0;
711   bool_id_ = 0;
712   void_id_ = 0;
713   storage_buffer_ext_defined_ = false;
714   uint32_rarr_ty_ = nullptr;
715   uint64_rarr_ty_ = nullptr;
716 
717   // clear collections
718   id2function_.clear();
719   id2block_.clear();
720 
721   // clear maps
722   param2input_func_id_.clear();
723   param2output_func_id_.clear();
724 
725   // Initialize function and block maps.
726   for (auto& fn : *get_module()) {
727     id2function_[fn.result_id()] = &fn;
728     for (auto& blk : fn) {
729       id2block_[blk.id()] = &blk;
730     }
731   }
732 
733   // Remember original instruction offsets
734   uint32_t module_offset = 0;
735   Module* module = get_module();
736   for (auto& i : context()->capabilities()) {
737     (void)i;
738     ++module_offset;
739   }
740   for (auto& i : module->extensions()) {
741     (void)i;
742     ++module_offset;
743   }
744   for (auto& i : module->ext_inst_imports()) {
745     (void)i;
746     ++module_offset;
747   }
748   ++module_offset;  // memory_model
749   for (auto& i : module->entry_points()) {
750     (void)i;
751     ++module_offset;
752   }
753   for (auto& i : module->execution_modes()) {
754     (void)i;
755     ++module_offset;
756   }
757   for (auto& i : module->debugs1()) {
758     (void)i;
759     ++module_offset;
760   }
761   for (auto& i : module->debugs2()) {
762     (void)i;
763     ++module_offset;
764   }
765   for (auto& i : module->debugs3()) {
766     (void)i;
767     ++module_offset;
768   }
769   for (auto& i : module->ext_inst_debuginfo()) {
770     (void)i;
771     ++module_offset;
772   }
773   for (auto& i : module->annotations()) {
774     (void)i;
775     ++module_offset;
776   }
777   for (auto& i : module->types_values()) {
778     module_offset += 1;
779     module_offset += static_cast<uint32_t>(i.dbg_line_insts().size());
780   }
781 
782   auto curr_fn = get_module()->begin();
783   for (; curr_fn != get_module()->end(); ++curr_fn) {
784     // Count function instruction
785     module_offset += 1;
786     curr_fn->ForEachParam(
787         [&module_offset](const Instruction*) { module_offset += 1; }, true);
788     for (auto& blk : *curr_fn) {
789       // Count label
790       module_offset += 1;
791       for (auto& inst : blk) {
792         module_offset += static_cast<uint32_t>(inst.dbg_line_insts().size());
793         uid2offset_[inst.unique_id()] = module_offset;
794         module_offset += 1;
795       }
796     }
797     // Count function end instruction
798     module_offset += 1;
799   }
800 }
801 
802 }  // namespace opt
803 }  // namespace spvtools
804