• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 
15 #include "source/fuzz/fuzzer_util.h"
16 
17 #include <algorithm>
18 #include <unordered_set>
19 
20 #include "source/opt/build_module.h"
21 
22 namespace spvtools {
23 namespace fuzz {
24 
25 namespace fuzzerutil {
26 namespace {
27 
28 // A utility class that uses RAII to change and restore the terminator
29 // instruction of the |block|.
30 class ChangeTerminatorRAII {
31  public:
ChangeTerminatorRAII(opt::BasicBlock * block,opt::Instruction new_terminator)32   explicit ChangeTerminatorRAII(opt::BasicBlock* block,
33                                 opt::Instruction new_terminator)
34       : block_(block), old_terminator_(std::move(*block->terminator())) {
35     *block_->terminator() = std::move(new_terminator);
36   }
37 
~ChangeTerminatorRAII()38   ~ChangeTerminatorRAII() {
39     *block_->terminator() = std::move(old_terminator_);
40   }
41 
42  private:
43   opt::BasicBlock* block_;
44   opt::Instruction old_terminator_;
45 };
46 
MaybeGetOpConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t type_id,bool is_irrelevant)47 uint32_t MaybeGetOpConstant(opt::IRContext* ir_context,
48                             const TransformationContext& transformation_context,
49                             const std::vector<uint32_t>& words,
50                             uint32_t type_id, bool is_irrelevant) {
51   for (const auto& inst : ir_context->types_values()) {
52     if (inst.opcode() == SpvOpConstant && inst.type_id() == type_id &&
53         inst.GetInOperand(0).words == words &&
54         transformation_context.GetFactManager()->IdIsIrrelevant(
55             inst.result_id()) == is_irrelevant) {
56       return inst.result_id();
57     }
58   }
59 
60   return 0;
61 }
62 
63 }  // namespace
64 
65 const spvtools::MessageConsumer kSilentMessageConsumer =
66     [](spv_message_level_t, const char*, const spv_position_t&,
__anon77d6a2b50202(spv_message_level_t, const char*, const spv_position_t&, const char*) 67        const char*) -> void {};
68 
BuildIRContext(spv_target_env target_env,const spvtools::MessageConsumer & message_consumer,const std::vector<uint32_t> & binary_in,spv_validator_options validator_options,std::unique_ptr<spvtools::opt::IRContext> * ir_context)69 bool BuildIRContext(spv_target_env target_env,
70                     const spvtools::MessageConsumer& message_consumer,
71                     const std::vector<uint32_t>& binary_in,
72                     spv_validator_options validator_options,
73                     std::unique_ptr<spvtools::opt::IRContext>* ir_context) {
74   SpirvTools tools(target_env);
75   tools.SetMessageConsumer(message_consumer);
76   if (!tools.IsValid()) {
77     message_consumer(SPV_MSG_ERROR, nullptr, {},
78                      "Failed to create SPIRV-Tools interface; stopping.");
79     return false;
80   }
81 
82   // Initial binary should be valid.
83   if (!tools.Validate(binary_in.data(), binary_in.size(), validator_options)) {
84     message_consumer(SPV_MSG_ERROR, nullptr, {},
85                      "Initial binary is invalid; stopping.");
86     return false;
87   }
88 
89   // Build the module from the input binary.
90   auto result = BuildModule(target_env, message_consumer, binary_in.data(),
91                             binary_in.size());
92   assert(result && "IRContext must be valid");
93   *ir_context = std::move(result);
94   return true;
95 }
96 
IsFreshId(opt::IRContext * context,uint32_t id)97 bool IsFreshId(opt::IRContext* context, uint32_t id) {
98   return !context->get_def_use_mgr()->GetDef(id);
99 }
100 
UpdateModuleIdBound(opt::IRContext * context,uint32_t id)101 void UpdateModuleIdBound(opt::IRContext* context, uint32_t id) {
102   // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/2541) consider the
103   //  case where the maximum id bound is reached.
104   context->module()->SetIdBound(
105       std::max(context->module()->id_bound(), id + 1));
106 }
107 
MaybeFindBlock(opt::IRContext * context,uint32_t maybe_block_id)108 opt::BasicBlock* MaybeFindBlock(opt::IRContext* context,
109                                 uint32_t maybe_block_id) {
110   auto inst = context->get_def_use_mgr()->GetDef(maybe_block_id);
111   if (inst == nullptr) {
112     // No instruction defining this id was found.
113     return nullptr;
114   }
115   if (inst->opcode() != SpvOpLabel) {
116     // The instruction defining the id is not a label, so it cannot be a block
117     // id.
118     return nullptr;
119   }
120   return context->cfg()->block(maybe_block_id);
121 }
122 
PhiIdsOkForNewEdge(opt::IRContext * context,opt::BasicBlock * bb_from,opt::BasicBlock * bb_to,const google::protobuf::RepeatedField<google::protobuf::uint32> & phi_ids)123 bool PhiIdsOkForNewEdge(
124     opt::IRContext* context, opt::BasicBlock* bb_from, opt::BasicBlock* bb_to,
125     const google::protobuf::RepeatedField<google::protobuf::uint32>& phi_ids) {
126   if (bb_from->IsSuccessor(bb_to)) {
127     // There is already an edge from |from_block| to |to_block|, so there is
128     // no need to extend OpPhi instructions.  Do not allow phi ids to be
129     // present. This might turn out to be too strict; perhaps it would be OK
130     // just to ignore the ids in this case.
131     return phi_ids.empty();
132   }
133   // The edge would add a previously non-existent edge from |from_block| to
134   // |to_block|, so we go through the given phi ids and check that they exactly
135   // match the OpPhi instructions in |to_block|.
136   uint32_t phi_index = 0;
137   // An explicit loop, rather than applying a lambda to each OpPhi in |bb_to|,
138   // makes sense here because we need to increment |phi_index| for each OpPhi
139   // instruction.
140   for (auto& inst : *bb_to) {
141     if (inst.opcode() != SpvOpPhi) {
142       // The OpPhi instructions all occur at the start of the block; if we find
143       // a non-OpPhi then we have seen them all.
144       break;
145     }
146     if (phi_index == static_cast<uint32_t>(phi_ids.size())) {
147       // Not enough phi ids have been provided to account for the OpPhi
148       // instructions.
149       return false;
150     }
151     // Look for an instruction defining the next phi id.
152     opt::Instruction* phi_extension =
153         context->get_def_use_mgr()->GetDef(phi_ids[phi_index]);
154     if (!phi_extension) {
155       // The id given to extend this OpPhi does not exist.
156       return false;
157     }
158     if (phi_extension->type_id() != inst.type_id()) {
159       // The instruction given to extend this OpPhi either does not have a type
160       // or its type does not match that of the OpPhi.
161       return false;
162     }
163 
164     if (context->get_instr_block(phi_extension)) {
165       // The instruction defining the phi id has an associated block (i.e., it
166       // is not a global value).  Check whether its definition dominates the
167       // exit of |from_block|.
168       auto dominator_analysis =
169           context->GetDominatorAnalysis(bb_from->GetParent());
170       if (!dominator_analysis->Dominates(phi_extension,
171                                          bb_from->terminator())) {
172         // The given id is no good as its definition does not dominate the exit
173         // of |from_block|
174         return false;
175       }
176     }
177     phi_index++;
178   }
179   // We allow some of the ids provided for extending OpPhi instructions to be
180   // unused.  Their presence does no harm, and requiring a perfect match may
181   // make transformations less likely to cleanly apply.
182   return true;
183 }
184 
CreateUnreachableEdgeInstruction(opt::IRContext * ir_context,uint32_t bb_from_id,uint32_t bb_to_id,uint32_t bool_id)185 opt::Instruction CreateUnreachableEdgeInstruction(opt::IRContext* ir_context,
186                                                   uint32_t bb_from_id,
187                                                   uint32_t bb_to_id,
188                                                   uint32_t bool_id) {
189   const auto* bb_from = MaybeFindBlock(ir_context, bb_from_id);
190   assert(bb_from && "|bb_from_id| is invalid");
191   assert(MaybeFindBlock(ir_context, bb_to_id) && "|bb_to_id| is invalid");
192   assert(bb_from->terminator()->opcode() == SpvOpBranch &&
193          "Precondition on terminator of bb_from is not satisfied");
194 
195   // Get the id of the boolean constant to be used as the condition.
196   auto condition_inst = ir_context->get_def_use_mgr()->GetDef(bool_id);
197   assert(condition_inst &&
198          (condition_inst->opcode() == SpvOpConstantTrue ||
199           condition_inst->opcode() == SpvOpConstantFalse) &&
200          "|bool_id| is invalid");
201 
202   auto condition_value = condition_inst->opcode() == SpvOpConstantTrue;
203   auto successor_id = bb_from->terminator()->GetSingleWordInOperand(0);
204 
205   // Add the dead branch, by turning OpBranch into OpBranchConditional, and
206   // ordering the targets depending on whether the given boolean corresponds to
207   // true or false.
208   return opt::Instruction(
209       ir_context, SpvOpBranchConditional, 0, 0,
210       {{SPV_OPERAND_TYPE_ID, {bool_id}},
211        {SPV_OPERAND_TYPE_ID, {condition_value ? successor_id : bb_to_id}},
212        {SPV_OPERAND_TYPE_ID, {condition_value ? bb_to_id : successor_id}}});
213 }
214 
AddUnreachableEdgeAndUpdateOpPhis(opt::IRContext * context,opt::BasicBlock * bb_from,opt::BasicBlock * bb_to,uint32_t bool_id,const google::protobuf::RepeatedField<google::protobuf::uint32> & phi_ids)215 void AddUnreachableEdgeAndUpdateOpPhis(
216     opt::IRContext* context, opt::BasicBlock* bb_from, opt::BasicBlock* bb_to,
217     uint32_t bool_id,
218     const google::protobuf::RepeatedField<google::protobuf::uint32>& phi_ids) {
219   assert(PhiIdsOkForNewEdge(context, bb_from, bb_to, phi_ids) &&
220          "Precondition on phi_ids is not satisfied");
221 
222   const bool from_to_edge_already_exists = bb_from->IsSuccessor(bb_to);
223   *bb_from->terminator() = CreateUnreachableEdgeInstruction(
224       context, bb_from->id(), bb_to->id(), bool_id);
225 
226   // Update OpPhi instructions in the target block if this branch adds a
227   // previously non-existent edge from source to target.
228   if (!from_to_edge_already_exists) {
229     uint32_t phi_index = 0;
230     for (auto& inst : *bb_to) {
231       if (inst.opcode() != SpvOpPhi) {
232         break;
233       }
234       assert(phi_index < static_cast<uint32_t>(phi_ids.size()) &&
235              "There should be at least one phi id per OpPhi instruction.");
236       inst.AddOperand({SPV_OPERAND_TYPE_ID, {phi_ids[phi_index]}});
237       inst.AddOperand({SPV_OPERAND_TYPE_ID, {bb_from->id()}});
238       phi_index++;
239     }
240   }
241 }
242 
BlockIsBackEdge(opt::IRContext * context,uint32_t block_id,uint32_t loop_header_id)243 bool BlockIsBackEdge(opt::IRContext* context, uint32_t block_id,
244                      uint32_t loop_header_id) {
245   auto block = context->cfg()->block(block_id);
246   auto loop_header = context->cfg()->block(loop_header_id);
247 
248   // |block| and |loop_header| must be defined, |loop_header| must be in fact
249   // loop header and |block| must branch to it.
250   if (!(block && loop_header && loop_header->IsLoopHeader() &&
251         block->IsSuccessor(loop_header))) {
252     return false;
253   }
254 
255   // |block_id| must be reachable and be dominated by |loop_header|.
256   opt::DominatorAnalysis* dominator_analysis =
257       context->GetDominatorAnalysis(loop_header->GetParent());
258   return dominator_analysis->IsReachable(block_id) &&
259          dominator_analysis->Dominates(loop_header_id, block_id);
260 }
261 
BlockIsInLoopContinueConstruct(opt::IRContext * context,uint32_t block_id,uint32_t maybe_loop_header_id)262 bool BlockIsInLoopContinueConstruct(opt::IRContext* context, uint32_t block_id,
263                                     uint32_t maybe_loop_header_id) {
264   // We deem a block to be part of a loop's continue construct if the loop's
265   // continue target dominates the block.
266   auto containing_construct_block = context->cfg()->block(maybe_loop_header_id);
267   if (containing_construct_block->IsLoopHeader()) {
268     auto continue_target = containing_construct_block->ContinueBlockId();
269     if (context->GetDominatorAnalysis(containing_construct_block->GetParent())
270             ->Dominates(continue_target, block_id)) {
271       return true;
272     }
273   }
274   return false;
275 }
276 
GetIteratorForInstruction(opt::BasicBlock * block,const opt::Instruction * inst)277 opt::BasicBlock::iterator GetIteratorForInstruction(
278     opt::BasicBlock* block, const opt::Instruction* inst) {
279   for (auto inst_it = block->begin(); inst_it != block->end(); ++inst_it) {
280     if (inst == &*inst_it) {
281       return inst_it;
282     }
283   }
284   return block->end();
285 }
286 
BlockIsReachableInItsFunction(opt::IRContext * context,opt::BasicBlock * bb)287 bool BlockIsReachableInItsFunction(opt::IRContext* context,
288                                    opt::BasicBlock* bb) {
289   auto enclosing_function = bb->GetParent();
290   return context->GetDominatorAnalysis(enclosing_function)
291       ->Dominates(enclosing_function->entry().get(), bb);
292 }
293 
CanInsertOpcodeBeforeInstruction(SpvOp opcode,const opt::BasicBlock::iterator & instruction_in_block)294 bool CanInsertOpcodeBeforeInstruction(
295     SpvOp opcode, const opt::BasicBlock::iterator& instruction_in_block) {
296   if (instruction_in_block->PreviousNode() &&
297       (instruction_in_block->PreviousNode()->opcode() == SpvOpLoopMerge ||
298        instruction_in_block->PreviousNode()->opcode() == SpvOpSelectionMerge)) {
299     // We cannot insert directly after a merge instruction.
300     return false;
301   }
302   if (opcode != SpvOpVariable &&
303       instruction_in_block->opcode() == SpvOpVariable) {
304     // We cannot insert a non-OpVariable instruction directly before a
305     // variable; variables in a function must be contiguous in the entry block.
306     return false;
307   }
308   // We cannot insert a non-OpPhi instruction directly before an OpPhi, because
309   // OpPhi instructions need to be contiguous at the start of a block.
310   return opcode == SpvOpPhi || instruction_in_block->opcode() != SpvOpPhi;
311 }
312 
CanMakeSynonymOf(opt::IRContext * ir_context,const TransformationContext & transformation_context,opt::Instruction * inst)313 bool CanMakeSynonymOf(opt::IRContext* ir_context,
314                       const TransformationContext& transformation_context,
315                       opt::Instruction* inst) {
316   if (inst->opcode() == SpvOpSampledImage) {
317     // The SPIR-V data rules say that only very specific instructions may
318     // may consume the result id of an OpSampledImage, and this excludes the
319     // instructions that are used for making synonyms.
320     return false;
321   }
322   if (!inst->HasResultId()) {
323     // We can only make a synonym of an instruction that generates an id.
324     return false;
325   }
326   if (transformation_context.GetFactManager()->IdIsIrrelevant(
327           inst->result_id())) {
328     // An irrelevant id can't be a synonym of anything.
329     return false;
330   }
331   if (!inst->type_id()) {
332     // We can only make a synonym of an instruction that has a type.
333     return false;
334   }
335   auto type_inst = ir_context->get_def_use_mgr()->GetDef(inst->type_id());
336   if (type_inst->opcode() == SpvOpTypeVoid) {
337     // We only make synonyms of instructions that define objects, and an object
338     // cannot have void type.
339     return false;
340   }
341   if (type_inst->opcode() == SpvOpTypePointer) {
342     switch (inst->opcode()) {
343       case SpvOpConstantNull:
344       case SpvOpUndef:
345         // We disallow making synonyms of null or undefined pointers.  This is
346         // to provide the property that if the original shader exhibited no bad
347         // pointer accesses, the transformed shader will not either.
348         return false;
349       default:
350         break;
351     }
352   }
353 
354   // We do not make synonyms of objects that have decorations: if the synonym is
355   // not decorated analogously, using the original object vs. its synonymous
356   // form may not be equivalent.
357   return ir_context->get_decoration_mgr()
358       ->GetDecorationsFor(inst->result_id(), true)
359       .empty();
360 }
361 
IsCompositeType(const opt::analysis::Type * type)362 bool IsCompositeType(const opt::analysis::Type* type) {
363   return type && (type->AsArray() || type->AsMatrix() || type->AsStruct() ||
364                   type->AsVector());
365 }
366 
RepeatedFieldToVector(const google::protobuf::RepeatedField<uint32_t> & repeated_field)367 std::vector<uint32_t> RepeatedFieldToVector(
368     const google::protobuf::RepeatedField<uint32_t>& repeated_field) {
369   std::vector<uint32_t> result;
370   for (auto i : repeated_field) {
371     result.push_back(i);
372   }
373   return result;
374 }
375 
WalkOneCompositeTypeIndex(opt::IRContext * context,uint32_t base_object_type_id,uint32_t index)376 uint32_t WalkOneCompositeTypeIndex(opt::IRContext* context,
377                                    uint32_t base_object_type_id,
378                                    uint32_t index) {
379   auto should_be_composite_type =
380       context->get_def_use_mgr()->GetDef(base_object_type_id);
381   assert(should_be_composite_type && "The type should exist.");
382   switch (should_be_composite_type->opcode()) {
383     case SpvOpTypeArray: {
384       auto array_length = GetArraySize(*should_be_composite_type, context);
385       if (array_length == 0 || index >= array_length) {
386         return 0;
387       }
388       return should_be_composite_type->GetSingleWordInOperand(0);
389     }
390     case SpvOpTypeMatrix:
391     case SpvOpTypeVector: {
392       auto count = should_be_composite_type->GetSingleWordInOperand(1);
393       if (index >= count) {
394         return 0;
395       }
396       return should_be_composite_type->GetSingleWordInOperand(0);
397     }
398     case SpvOpTypeStruct: {
399       if (index >= GetNumberOfStructMembers(*should_be_composite_type)) {
400         return 0;
401       }
402       return should_be_composite_type->GetSingleWordInOperand(index);
403     }
404     default:
405       return 0;
406   }
407 }
408 
WalkCompositeTypeIndices(opt::IRContext * context,uint32_t base_object_type_id,const google::protobuf::RepeatedField<google::protobuf::uint32> & indices)409 uint32_t WalkCompositeTypeIndices(
410     opt::IRContext* context, uint32_t base_object_type_id,
411     const google::protobuf::RepeatedField<google::protobuf::uint32>& indices) {
412   uint32_t sub_object_type_id = base_object_type_id;
413   for (auto index : indices) {
414     sub_object_type_id =
415         WalkOneCompositeTypeIndex(context, sub_object_type_id, index);
416     if (!sub_object_type_id) {
417       return 0;
418     }
419   }
420   return sub_object_type_id;
421 }
422 
GetNumberOfStructMembers(const opt::Instruction & struct_type_instruction)423 uint32_t GetNumberOfStructMembers(
424     const opt::Instruction& struct_type_instruction) {
425   assert(struct_type_instruction.opcode() == SpvOpTypeStruct &&
426          "An OpTypeStruct instruction is required here.");
427   return struct_type_instruction.NumInOperands();
428 }
429 
GetArraySize(const opt::Instruction & array_type_instruction,opt::IRContext * context)430 uint32_t GetArraySize(const opt::Instruction& array_type_instruction,
431                       opt::IRContext* context) {
432   auto array_length_constant =
433       context->get_constant_mgr()
434           ->GetConstantFromInst(context->get_def_use_mgr()->GetDef(
435               array_type_instruction.GetSingleWordInOperand(1)))
436           ->AsIntConstant();
437   if (array_length_constant->words().size() != 1) {
438     return 0;
439   }
440   return array_length_constant->GetU32();
441 }
442 
GetBoundForCompositeIndex(const opt::Instruction & composite_type_inst,opt::IRContext * ir_context)443 uint32_t GetBoundForCompositeIndex(const opt::Instruction& composite_type_inst,
444                                    opt::IRContext* ir_context) {
445   switch (composite_type_inst.opcode()) {
446     case SpvOpTypeArray:
447       return fuzzerutil::GetArraySize(composite_type_inst, ir_context);
448     case SpvOpTypeMatrix:
449     case SpvOpTypeVector:
450       return composite_type_inst.GetSingleWordInOperand(1);
451     case SpvOpTypeStruct: {
452       return fuzzerutil::GetNumberOfStructMembers(composite_type_inst);
453     }
454     case SpvOpTypeRuntimeArray:
455       assert(false &&
456              "GetBoundForCompositeIndex should not be invoked with an "
457              "OpTypeRuntimeArray, which does not have a static bound.");
458       return 0;
459     default:
460       assert(false && "Unknown composite type.");
461       return 0;
462   }
463 }
464 
IsValid(const opt::IRContext * context,spv_validator_options validator_options,MessageConsumer consumer)465 bool IsValid(const opt::IRContext* context,
466              spv_validator_options validator_options,
467              MessageConsumer consumer) {
468   std::vector<uint32_t> binary;
469   context->module()->ToBinary(&binary, false);
470   SpirvTools tools(context->grammar().target_env());
471   tools.SetMessageConsumer(std::move(consumer));
472   return tools.Validate(binary.data(), binary.size(), validator_options);
473 }
474 
IsValidAndWellFormed(const opt::IRContext * ir_context,spv_validator_options validator_options,MessageConsumer consumer)475 bool IsValidAndWellFormed(const opt::IRContext* ir_context,
476                           spv_validator_options validator_options,
477                           MessageConsumer consumer) {
478   if (!IsValid(ir_context, validator_options, consumer)) {
479     // Expression to dump |ir_context| to /data/temp/shader.spv:
480     //    DumpShader(ir_context, "/data/temp/shader.spv")
481     consumer(SPV_MSG_INFO, nullptr, {},
482              "Module is invalid (set a breakpoint to inspect).");
483     return false;
484   }
485   // Check that all blocks in the module have appropriate parent functions.
486   for (auto& function : *ir_context->module()) {
487     for (auto& block : function) {
488       if (block.GetParent() == nullptr) {
489         std::stringstream ss;
490         ss << "Block " << block.id() << " has no parent; its parent should be "
491            << function.result_id() << " (set a breakpoint to inspect).";
492         consumer(SPV_MSG_INFO, nullptr, {}, ss.str().c_str());
493         return false;
494       }
495       if (block.GetParent() != &function) {
496         std::stringstream ss;
497         ss << "Block " << block.id() << " should have parent "
498            << function.result_id() << " but instead has parent "
499            << block.GetParent() << " (set a breakpoint to inspect).";
500         consumer(SPV_MSG_INFO, nullptr, {}, ss.str().c_str());
501         return false;
502       }
503     }
504   }
505 
506   // Check that all instructions have distinct unique ids.  We map each unique
507   // id to the first instruction it is observed to be associated with so that
508   // if we encounter a duplicate we have access to the previous instruction -
509   // this is a useful aid to debugging.
510   std::unordered_map<uint32_t, opt::Instruction*> unique_ids;
511   bool found_duplicate = false;
512   ir_context->module()->ForEachInst([&consumer, &found_duplicate,
513                                      &unique_ids](opt::Instruction* inst) {
514     if (unique_ids.count(inst->unique_id()) != 0) {
515       consumer(SPV_MSG_INFO, nullptr, {},
516                "Two instructions have the same unique id (set a breakpoint to "
517                "inspect).");
518       found_duplicate = true;
519     }
520     unique_ids.insert({inst->unique_id(), inst});
521   });
522   return !found_duplicate;
523 }
524 
CloneIRContext(opt::IRContext * context)525 std::unique_ptr<opt::IRContext> CloneIRContext(opt::IRContext* context) {
526   std::vector<uint32_t> binary;
527   context->module()->ToBinary(&binary, false);
528   return BuildModule(context->grammar().target_env(), nullptr, binary.data(),
529                      binary.size());
530 }
531 
IsNonFunctionTypeId(opt::IRContext * ir_context,uint32_t id)532 bool IsNonFunctionTypeId(opt::IRContext* ir_context, uint32_t id) {
533   auto type = ir_context->get_type_mgr()->GetType(id);
534   return type && !type->AsFunction();
535 }
536 
IsMergeOrContinue(opt::IRContext * ir_context,uint32_t block_id)537 bool IsMergeOrContinue(opt::IRContext* ir_context, uint32_t block_id) {
538   bool result = false;
539   ir_context->get_def_use_mgr()->WhileEachUse(
540       block_id,
541       [&result](const opt::Instruction* use_instruction,
542                 uint32_t /*unused*/) -> bool {
543         switch (use_instruction->opcode()) {
544           case SpvOpLoopMerge:
545           case SpvOpSelectionMerge:
546             result = true;
547             return false;
548           default:
549             return true;
550         }
551       });
552   return result;
553 }
554 
GetLoopFromMergeBlock(opt::IRContext * ir_context,uint32_t merge_block_id)555 uint32_t GetLoopFromMergeBlock(opt::IRContext* ir_context,
556                                uint32_t merge_block_id) {
557   uint32_t result = 0;
558   ir_context->get_def_use_mgr()->WhileEachUse(
559       merge_block_id,
560       [ir_context, &result](opt::Instruction* use_instruction,
561                             uint32_t use_index) -> bool {
562         switch (use_instruction->opcode()) {
563           case SpvOpLoopMerge:
564             // The merge block operand is the first operand in OpLoopMerge.
565             if (use_index == 0) {
566               result = ir_context->get_instr_block(use_instruction)->id();
567               return false;
568             }
569             return true;
570           default:
571             return true;
572         }
573       });
574   return result;
575 }
576 
FindFunctionType(opt::IRContext * ir_context,const std::vector<uint32_t> & type_ids)577 uint32_t FindFunctionType(opt::IRContext* ir_context,
578                           const std::vector<uint32_t>& type_ids) {
579   // Look through the existing types for a match.
580   for (auto& type_or_value : ir_context->types_values()) {
581     if (type_or_value.opcode() != SpvOpTypeFunction) {
582       // We are only interested in function types.
583       continue;
584     }
585     if (type_or_value.NumInOperands() != type_ids.size()) {
586       // Not a match: different numbers of arguments.
587       continue;
588     }
589     // Check whether the return type and argument types match.
590     bool input_operands_match = true;
591     for (uint32_t i = 0; i < type_or_value.NumInOperands(); i++) {
592       if (type_ids[i] != type_or_value.GetSingleWordInOperand(i)) {
593         input_operands_match = false;
594         break;
595       }
596     }
597     if (input_operands_match) {
598       // Everything matches.
599       return type_or_value.result_id();
600     }
601   }
602   // No match was found.
603   return 0;
604 }
605 
GetFunctionType(opt::IRContext * context,const opt::Function * function)606 opt::Instruction* GetFunctionType(opt::IRContext* context,
607                                   const opt::Function* function) {
608   uint32_t type_id = function->DefInst().GetSingleWordInOperand(1);
609   return context->get_def_use_mgr()->GetDef(type_id);
610 }
611 
FindFunction(opt::IRContext * ir_context,uint32_t function_id)612 opt::Function* FindFunction(opt::IRContext* ir_context, uint32_t function_id) {
613   for (auto& function : *ir_context->module()) {
614     if (function.result_id() == function_id) {
615       return &function;
616     }
617   }
618   return nullptr;
619 }
620 
FunctionContainsOpKillOrUnreachable(const opt::Function & function)621 bool FunctionContainsOpKillOrUnreachable(const opt::Function& function) {
622   for (auto& block : function) {
623     if (block.terminator()->opcode() == SpvOpKill ||
624         block.terminator()->opcode() == SpvOpUnreachable) {
625       return true;
626     }
627   }
628   return false;
629 }
630 
FunctionIsEntryPoint(opt::IRContext * context,uint32_t function_id)631 bool FunctionIsEntryPoint(opt::IRContext* context, uint32_t function_id) {
632   for (auto& entry_point : context->module()->entry_points()) {
633     if (entry_point.GetSingleWordInOperand(1) == function_id) {
634       return true;
635     }
636   }
637   return false;
638 }
639 
IdIsAvailableAtUse(opt::IRContext * context,opt::Instruction * use_instruction,uint32_t use_input_operand_index,uint32_t id)640 bool IdIsAvailableAtUse(opt::IRContext* context,
641                         opt::Instruction* use_instruction,
642                         uint32_t use_input_operand_index, uint32_t id) {
643   assert(context->get_instr_block(use_instruction) &&
644          "|use_instruction| must be in a basic block");
645 
646   auto defining_instruction = context->get_def_use_mgr()->GetDef(id);
647   auto enclosing_function =
648       context->get_instr_block(use_instruction)->GetParent();
649   // If the id a function parameter, it needs to be associated with the
650   // function containing the use.
651   if (defining_instruction->opcode() == SpvOpFunctionParameter) {
652     return InstructionIsFunctionParameter(defining_instruction,
653                                           enclosing_function);
654   }
655   if (!context->get_instr_block(id)) {
656     // The id must be at global scope.
657     return true;
658   }
659   if (defining_instruction == use_instruction) {
660     // It is not OK for a definition to use itself.
661     return false;
662   }
663   auto dominator_analysis = context->GetDominatorAnalysis(enclosing_function);
664   if (!dominator_analysis->IsReachable(
665           context->get_instr_block(use_instruction)) ||
666       !dominator_analysis->IsReachable(context->get_instr_block(id))) {
667     // Skip unreachable blocks.
668     return false;
669   }
670   if (use_instruction->opcode() == SpvOpPhi) {
671     // In the case where the use is an operand to OpPhi, it is actually the
672     // *parent* block associated with the operand that must be dominated by
673     // the synonym.
674     auto parent_block =
675         use_instruction->GetSingleWordInOperand(use_input_operand_index + 1);
676     return dominator_analysis->Dominates(
677         context->get_instr_block(defining_instruction)->id(), parent_block);
678   }
679   return dominator_analysis->Dominates(defining_instruction, use_instruction);
680 }
681 
IdIsAvailableBeforeInstruction(opt::IRContext * context,opt::Instruction * instruction,uint32_t id)682 bool IdIsAvailableBeforeInstruction(opt::IRContext* context,
683                                     opt::Instruction* instruction,
684                                     uint32_t id) {
685   assert(context->get_instr_block(instruction) &&
686          "|instruction| must be in a basic block");
687 
688   auto id_definition = context->get_def_use_mgr()->GetDef(id);
689   auto function_enclosing_instruction =
690       context->get_instr_block(instruction)->GetParent();
691   // If the id a function parameter, it needs to be associated with the
692   // function containing the instruction.
693   if (id_definition->opcode() == SpvOpFunctionParameter) {
694     return InstructionIsFunctionParameter(id_definition,
695                                           function_enclosing_instruction);
696   }
697   if (!context->get_instr_block(id)) {
698     // The id is at global scope.
699     return true;
700   }
701   if (id_definition == instruction) {
702     // The instruction is not available right before its own definition.
703     return false;
704   }
705   const auto* dominator_analysis =
706       context->GetDominatorAnalysis(function_enclosing_instruction);
707   if (dominator_analysis->IsReachable(context->get_instr_block(instruction)) &&
708       dominator_analysis->IsReachable(context->get_instr_block(id)) &&
709       dominator_analysis->Dominates(id_definition, instruction)) {
710     // The id's definition dominates the instruction, and both the definition
711     // and the instruction are in reachable blocks, thus the id is available at
712     // the instruction.
713     return true;
714   }
715   if (id_definition->opcode() == SpvOpVariable &&
716       function_enclosing_instruction ==
717           context->get_instr_block(id)->GetParent()) {
718     assert(!dominator_analysis->IsReachable(
719                context->get_instr_block(instruction)) &&
720            "If the instruction were in a reachable block we should already "
721            "have returned true.");
722     // The id is a variable and it is in the same function as |instruction|.
723     // This is OK despite |instruction| being unreachable.
724     return true;
725   }
726   return false;
727 }
728 
InstructionIsFunctionParameter(opt::Instruction * instruction,opt::Function * function)729 bool InstructionIsFunctionParameter(opt::Instruction* instruction,
730                                     opt::Function* function) {
731   if (instruction->opcode() != SpvOpFunctionParameter) {
732     return false;
733   }
734   bool found_parameter = false;
735   function->ForEachParam(
736       [instruction, &found_parameter](opt::Instruction* param) {
737         if (param == instruction) {
738           found_parameter = true;
739         }
740       });
741   return found_parameter;
742 }
743 
GetTypeId(opt::IRContext * context,uint32_t result_id)744 uint32_t GetTypeId(opt::IRContext* context, uint32_t result_id) {
745   const auto* inst = context->get_def_use_mgr()->GetDef(result_id);
746   assert(inst && "|result_id| is invalid");
747   return inst->type_id();
748 }
749 
GetPointeeTypeIdFromPointerType(opt::Instruction * pointer_type_inst)750 uint32_t GetPointeeTypeIdFromPointerType(opt::Instruction* pointer_type_inst) {
751   assert(pointer_type_inst && pointer_type_inst->opcode() == SpvOpTypePointer &&
752          "Precondition: |pointer_type_inst| must be OpTypePointer.");
753   return pointer_type_inst->GetSingleWordInOperand(1);
754 }
755 
GetPointeeTypeIdFromPointerType(opt::IRContext * context,uint32_t pointer_type_id)756 uint32_t GetPointeeTypeIdFromPointerType(opt::IRContext* context,
757                                          uint32_t pointer_type_id) {
758   return GetPointeeTypeIdFromPointerType(
759       context->get_def_use_mgr()->GetDef(pointer_type_id));
760 }
761 
GetStorageClassFromPointerType(opt::Instruction * pointer_type_inst)762 SpvStorageClass GetStorageClassFromPointerType(
763     opt::Instruction* pointer_type_inst) {
764   assert(pointer_type_inst && pointer_type_inst->opcode() == SpvOpTypePointer &&
765          "Precondition: |pointer_type_inst| must be OpTypePointer.");
766   return static_cast<SpvStorageClass>(
767       pointer_type_inst->GetSingleWordInOperand(0));
768 }
769 
GetStorageClassFromPointerType(opt::IRContext * context,uint32_t pointer_type_id)770 SpvStorageClass GetStorageClassFromPointerType(opt::IRContext* context,
771                                                uint32_t pointer_type_id) {
772   return GetStorageClassFromPointerType(
773       context->get_def_use_mgr()->GetDef(pointer_type_id));
774 }
775 
MaybeGetPointerType(opt::IRContext * context,uint32_t pointee_type_id,SpvStorageClass storage_class)776 uint32_t MaybeGetPointerType(opt::IRContext* context, uint32_t pointee_type_id,
777                              SpvStorageClass storage_class) {
778   for (auto& inst : context->types_values()) {
779     switch (inst.opcode()) {
780       case SpvOpTypePointer:
781         if (inst.GetSingleWordInOperand(0) == storage_class &&
782             inst.GetSingleWordInOperand(1) == pointee_type_id) {
783           return inst.result_id();
784         }
785         break;
786       default:
787         break;
788     }
789   }
790   return 0;
791 }
792 
InOperandIndexFromOperandIndex(const opt::Instruction & inst,uint32_t absolute_index)793 uint32_t InOperandIndexFromOperandIndex(const opt::Instruction& inst,
794                                         uint32_t absolute_index) {
795   // Subtract the number of non-input operands from the index
796   return absolute_index - inst.NumOperands() + inst.NumInOperands();
797 }
798 
IsNullConstantSupported(const opt::analysis::Type & type)799 bool IsNullConstantSupported(const opt::analysis::Type& type) {
800   return type.AsBool() || type.AsInteger() || type.AsFloat() ||
801          type.AsMatrix() || type.AsVector() || type.AsArray() ||
802          type.AsStruct() || type.AsPointer() || type.AsEvent() ||
803          type.AsDeviceEvent() || type.AsReserveId() || type.AsQueue();
804 }
805 
GlobalVariablesMustBeDeclaredInEntryPointInterfaces(const opt::IRContext * ir_context)806 bool GlobalVariablesMustBeDeclaredInEntryPointInterfaces(
807     const opt::IRContext* ir_context) {
808   // TODO(afd): We capture the environments for which this requirement holds.
809   //  The check should be refined on demand for other target environments.
810   switch (ir_context->grammar().target_env()) {
811     case SPV_ENV_UNIVERSAL_1_0:
812     case SPV_ENV_UNIVERSAL_1_1:
813     case SPV_ENV_UNIVERSAL_1_2:
814     case SPV_ENV_UNIVERSAL_1_3:
815     case SPV_ENV_VULKAN_1_0:
816     case SPV_ENV_VULKAN_1_1:
817       return false;
818     default:
819       return true;
820   }
821 }
822 
AddVariableIdToEntryPointInterfaces(opt::IRContext * context,uint32_t id)823 void AddVariableIdToEntryPointInterfaces(opt::IRContext* context, uint32_t id) {
824   if (GlobalVariablesMustBeDeclaredInEntryPointInterfaces(context)) {
825     // Conservatively add this global to the interface of every entry point in
826     // the module.  This means that the global is available for other
827     // transformations to use.
828     //
829     // A downside of this is that the global will be in the interface even if it
830     // ends up never being used.
831     //
832     // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3111) revisit
833     //  this if a more thorough approach to entry point interfaces is taken.
834     for (auto& entry_point : context->module()->entry_points()) {
835       entry_point.AddOperand({SPV_OPERAND_TYPE_ID, {id}});
836     }
837   }
838 }
839 
AddGlobalVariable(opt::IRContext * context,uint32_t result_id,uint32_t type_id,SpvStorageClass storage_class,uint32_t initializer_id)840 opt::Instruction* AddGlobalVariable(opt::IRContext* context, uint32_t result_id,
841                                     uint32_t type_id,
842                                     SpvStorageClass storage_class,
843                                     uint32_t initializer_id) {
844   // Check various preconditions.
845   assert(result_id != 0 && "Result id can't be 0");
846 
847   assert((storage_class == SpvStorageClassPrivate ||
848           storage_class == SpvStorageClassWorkgroup) &&
849          "Variable's storage class must be either Private or Workgroup");
850 
851   auto* type_inst = context->get_def_use_mgr()->GetDef(type_id);
852   (void)type_inst;  // Variable becomes unused in release mode.
853   assert(type_inst && type_inst->opcode() == SpvOpTypePointer &&
854          GetStorageClassFromPointerType(type_inst) == storage_class &&
855          "Variable's type is invalid");
856 
857   if (storage_class == SpvStorageClassWorkgroup) {
858     assert(initializer_id == 0);
859   }
860 
861   if (initializer_id != 0) {
862     const auto* constant_inst =
863         context->get_def_use_mgr()->GetDef(initializer_id);
864     (void)constant_inst;  // Variable becomes unused in release mode.
865     assert(constant_inst && spvOpcodeIsConstant(constant_inst->opcode()) &&
866            GetPointeeTypeIdFromPointerType(type_inst) ==
867                constant_inst->type_id() &&
868            "Initializer is invalid");
869   }
870 
871   opt::Instruction::OperandList operands = {
872       {SPV_OPERAND_TYPE_STORAGE_CLASS, {static_cast<uint32_t>(storage_class)}}};
873 
874   if (initializer_id) {
875     operands.push_back({SPV_OPERAND_TYPE_ID, {initializer_id}});
876   }
877 
878   auto new_instruction = MakeUnique<opt::Instruction>(
879       context, SpvOpVariable, type_id, result_id, std::move(operands));
880   auto result = new_instruction.get();
881   context->module()->AddGlobalValue(std::move(new_instruction));
882 
883   AddVariableIdToEntryPointInterfaces(context, result_id);
884   UpdateModuleIdBound(context, result_id);
885 
886   return result;
887 }
888 
AddLocalVariable(opt::IRContext * context,uint32_t result_id,uint32_t type_id,uint32_t function_id,uint32_t initializer_id)889 opt::Instruction* AddLocalVariable(opt::IRContext* context, uint32_t result_id,
890                                    uint32_t type_id, uint32_t function_id,
891                                    uint32_t initializer_id) {
892   // Check various preconditions.
893   assert(result_id != 0 && "Result id can't be 0");
894 
895   auto* type_inst = context->get_def_use_mgr()->GetDef(type_id);
896   (void)type_inst;  // Variable becomes unused in release mode.
897   assert(type_inst && type_inst->opcode() == SpvOpTypePointer &&
898          GetStorageClassFromPointerType(type_inst) == SpvStorageClassFunction &&
899          "Variable's type is invalid");
900 
901   const auto* constant_inst =
902       context->get_def_use_mgr()->GetDef(initializer_id);
903   (void)constant_inst;  // Variable becomes unused in release mode.
904   assert(constant_inst && spvOpcodeIsConstant(constant_inst->opcode()) &&
905          GetPointeeTypeIdFromPointerType(type_inst) ==
906              constant_inst->type_id() &&
907          "Initializer is invalid");
908 
909   auto* function = FindFunction(context, function_id);
910   assert(function && "Function id is invalid");
911 
912   auto new_instruction = MakeUnique<opt::Instruction>(
913       context, SpvOpVariable, type_id, result_id,
914       opt::Instruction::OperandList{
915           {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
916           {SPV_OPERAND_TYPE_ID, {initializer_id}}});
917   auto result = new_instruction.get();
918   function->begin()->begin()->InsertBefore(std::move(new_instruction));
919 
920   UpdateModuleIdBound(context, result_id);
921 
922   return result;
923 }
924 
// Returns true if and only if some value occurs more than once in |arr|.
bool HasDuplicates(const std::vector<uint32_t>& arr) {
  std::unordered_set<uint32_t> seen;
  seen.reserve(arr.size());
  for (const uint32_t value : arr) {
    // insert() reports through .second whether the value was newly added.
    if (!seen.insert(value).second) {
      return true;
    }
  }
  return false;
}
929 
// Returns true if and only if |arr| is a permutation of the inclusive range
// [lo, hi].  An empty |arr| is a permutation only of an empty range, i.e. one
// with |lo| > |hi|.
bool IsPermutationOfRange(const std::vector<uint32_t>& arr, uint32_t lo,
                          uint32_t hi) {
  if (arr.empty()) {
    return lo > hi;
  }

  // A permutation never repeats a value.
  const std::unordered_set<uint32_t> distinct_values(arr.begin(), arr.end());
  if (distinct_values.size() != arr.size()) {
    return false;
  }

  // With no repeats, the right element count together with the right extremes
  // pins down the exact set of values.
  const uint32_t smallest = *std::min_element(arr.begin(), arr.end());
  const uint32_t largest = *std::max_element(arr.begin(), arr.end());
  return smallest == lo && largest == hi && arr.size() == hi - lo + 1;
}
944 
GetParameters(opt::IRContext * ir_context,uint32_t function_id)945 std::vector<opt::Instruction*> GetParameters(opt::IRContext* ir_context,
946                                              uint32_t function_id) {
947   auto* function = FindFunction(ir_context, function_id);
948   assert(function && "|function_id| is invalid");
949 
950   std::vector<opt::Instruction*> result;
951   function->ForEachParam(
952       [&result](opt::Instruction* inst) { result.push_back(inst); });
953 
954   return result;
955 }
956 
RemoveParameter(opt::IRContext * ir_context,uint32_t parameter_id)957 void RemoveParameter(opt::IRContext* ir_context, uint32_t parameter_id) {
958   auto* function = GetFunctionFromParameterId(ir_context, parameter_id);
959   assert(function && "|parameter_id| is invalid");
960   assert(!FunctionIsEntryPoint(ir_context, function->result_id()) &&
961          "Can't remove parameter from an entry point function");
962 
963   function->RemoveParameter(parameter_id);
964 
965   // We've just removed parameters from the function and cleared their memory.
966   // Make sure analyses have no dangling pointers.
967   ir_context->InvalidateAnalysesExceptFor(
968       opt::IRContext::Analysis::kAnalysisNone);
969 }
970 
GetCallers(opt::IRContext * ir_context,uint32_t function_id)971 std::vector<opt::Instruction*> GetCallers(opt::IRContext* ir_context,
972                                           uint32_t function_id) {
973   assert(FindFunction(ir_context, function_id) &&
974          "|function_id| is not a result id of a function");
975 
976   std::vector<opt::Instruction*> result;
977   ir_context->get_def_use_mgr()->ForEachUser(
978       function_id, [&result, function_id](opt::Instruction* inst) {
979         if (inst->opcode() == SpvOpFunctionCall &&
980             inst->GetSingleWordInOperand(0) == function_id) {
981           result.push_back(inst);
982         }
983       });
984 
985   return result;
986 }
987 
GetFunctionFromParameterId(opt::IRContext * ir_context,uint32_t param_id)988 opt::Function* GetFunctionFromParameterId(opt::IRContext* ir_context,
989                                           uint32_t param_id) {
990   auto* param_inst = ir_context->get_def_use_mgr()->GetDef(param_id);
991   assert(param_inst && "Parameter id is invalid");
992 
993   for (auto& function : *ir_context->module()) {
994     if (InstructionIsFunctionParameter(param_inst, &function)) {
995       return &function;
996     }
997   }
998 
999   return nullptr;
1000 }
1001 
UpdateFunctionType(opt::IRContext * ir_context,uint32_t function_id,uint32_t new_function_type_result_id,uint32_t return_type_id,const std::vector<uint32_t> & parameter_type_ids)1002 uint32_t UpdateFunctionType(opt::IRContext* ir_context, uint32_t function_id,
1003                             uint32_t new_function_type_result_id,
1004                             uint32_t return_type_id,
1005                             const std::vector<uint32_t>& parameter_type_ids) {
1006   // Check some initial constraints.
1007   assert(ir_context->get_type_mgr()->GetType(return_type_id) &&
1008          "Return type is invalid");
1009   for (auto id : parameter_type_ids) {
1010     const auto* type = ir_context->get_type_mgr()->GetType(id);
1011     (void)type;  // Make compilers happy in release mode.
1012     // Parameters can't be OpTypeVoid.
1013     assert(type && !type->AsVoid() && "Parameter has invalid type");
1014   }
1015 
1016   auto* function = FindFunction(ir_context, function_id);
1017   assert(function && "|function_id| is invalid");
1018 
1019   auto* old_function_type = GetFunctionType(ir_context, function);
1020   assert(old_function_type && "Function has invalid type");
1021 
1022   std::vector<uint32_t> operand_ids = {return_type_id};
1023   operand_ids.insert(operand_ids.end(), parameter_type_ids.begin(),
1024                      parameter_type_ids.end());
1025 
1026   // A trivial case - we change nothing.
1027   if (FindFunctionType(ir_context, operand_ids) ==
1028       old_function_type->result_id()) {
1029     return old_function_type->result_id();
1030   }
1031 
1032   if (ir_context->get_def_use_mgr()->NumUsers(old_function_type) == 1 &&
1033       FindFunctionType(ir_context, operand_ids) == 0) {
1034     // We can change |old_function_type| only if it's used once in the module
1035     // and we are certain we won't create a duplicate as a result of the change.
1036 
1037     // Update |old_function_type| in-place.
1038     opt::Instruction::OperandList operands;
1039     for (auto id : operand_ids) {
1040       operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
1041     }
1042 
1043     old_function_type->SetInOperands(std::move(operands));
1044 
1045     // |operands| may depend on result ids defined below the |old_function_type|
1046     // in the module.
1047     old_function_type->RemoveFromList();
1048     ir_context->AddType(std::unique_ptr<opt::Instruction>(old_function_type));
1049     return old_function_type->result_id();
1050   } else {
1051     // We can't modify the |old_function_type| so we have to either use an
1052     // existing one or create a new one.
1053     auto type_id = FindOrCreateFunctionType(
1054         ir_context, new_function_type_result_id, operand_ids);
1055     assert(type_id != old_function_type->result_id() &&
1056            "We should've handled this case above");
1057 
1058     function->DefInst().SetInOperand(1, {type_id});
1059 
1060     // DefUseManager hasn't been updated yet, so if the following condition is
1061     // true, then |old_function_type| will have no users when this function
1062     // returns. We might as well remove it.
1063     if (ir_context->get_def_use_mgr()->NumUsers(old_function_type) == 1) {
1064       ir_context->KillInst(old_function_type);
1065     }
1066 
1067     return type_id;
1068   }
1069 }
1070 
AddFunctionType(opt::IRContext * ir_context,uint32_t result_id,const std::vector<uint32_t> & type_ids)1071 void AddFunctionType(opt::IRContext* ir_context, uint32_t result_id,
1072                      const std::vector<uint32_t>& type_ids) {
1073   assert(result_id != 0 && "Result id can't be 0");
1074   assert(!type_ids.empty() &&
1075          "OpTypeFunction always has at least one operand - function's return "
1076          "type");
1077   assert(IsNonFunctionTypeId(ir_context, type_ids[0]) &&
1078          "Return type must not be a function");
1079 
1080   for (size_t i = 1; i < type_ids.size(); ++i) {
1081     const auto* param_type = ir_context->get_type_mgr()->GetType(type_ids[i]);
1082     (void)param_type;  // Make compiler happy in release mode.
1083     assert(param_type && !param_type->AsVoid() && !param_type->AsFunction() &&
1084            "Function parameter can't have a function or void type");
1085   }
1086 
1087   opt::Instruction::OperandList operands;
1088   operands.reserve(type_ids.size());
1089   for (auto id : type_ids) {
1090     operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
1091   }
1092 
1093   ir_context->AddType(MakeUnique<opt::Instruction>(
1094       ir_context, SpvOpTypeFunction, 0, result_id, std::move(operands)));
1095 
1096   UpdateModuleIdBound(ir_context, result_id);
1097 }
1098 
FindOrCreateFunctionType(opt::IRContext * ir_context,uint32_t result_id,const std::vector<uint32_t> & type_ids)1099 uint32_t FindOrCreateFunctionType(opt::IRContext* ir_context,
1100                                   uint32_t result_id,
1101                                   const std::vector<uint32_t>& type_ids) {
1102   if (auto existing_id = FindFunctionType(ir_context, type_ids)) {
1103     return existing_id;
1104   }
1105   AddFunctionType(ir_context, result_id, type_ids);
1106   return result_id;
1107 }
1108 
MaybeGetIntegerType(opt::IRContext * ir_context,uint32_t width,bool is_signed)1109 uint32_t MaybeGetIntegerType(opt::IRContext* ir_context, uint32_t width,
1110                              bool is_signed) {
1111   opt::analysis::Integer type(width, is_signed);
1112   return ir_context->get_type_mgr()->GetId(&type);
1113 }
1114 
MaybeGetFloatType(opt::IRContext * ir_context,uint32_t width)1115 uint32_t MaybeGetFloatType(opt::IRContext* ir_context, uint32_t width) {
1116   opt::analysis::Float type(width);
1117   return ir_context->get_type_mgr()->GetId(&type);
1118 }
1119 
MaybeGetBoolType(opt::IRContext * ir_context)1120 uint32_t MaybeGetBoolType(opt::IRContext* ir_context) {
1121   opt::analysis::Bool type;
1122   return ir_context->get_type_mgr()->GetId(&type);
1123 }
1124 
MaybeGetVectorType(opt::IRContext * ir_context,uint32_t component_type_id,uint32_t element_count)1125 uint32_t MaybeGetVectorType(opt::IRContext* ir_context,
1126                             uint32_t component_type_id,
1127                             uint32_t element_count) {
1128   const auto* component_type =
1129       ir_context->get_type_mgr()->GetType(component_type_id);
1130   assert(component_type &&
1131          (component_type->AsInteger() || component_type->AsFloat() ||
1132           component_type->AsBool()) &&
1133          "|component_type_id| is invalid");
1134   assert(element_count >= 2 && element_count <= 4 &&
1135          "Precondition: component count must be in range [2, 4].");
1136   opt::analysis::Vector type(component_type, element_count);
1137   return ir_context->get_type_mgr()->GetId(&type);
1138 }
1139 
MaybeGetStructType(opt::IRContext * ir_context,const std::vector<uint32_t> & component_type_ids)1140 uint32_t MaybeGetStructType(opt::IRContext* ir_context,
1141                             const std::vector<uint32_t>& component_type_ids) {
1142   for (auto& type_or_value : ir_context->types_values()) {
1143     if (type_or_value.opcode() != SpvOpTypeStruct ||
1144         type_or_value.NumInOperands() !=
1145             static_cast<uint32_t>(component_type_ids.size())) {
1146       continue;
1147     }
1148     bool all_components_match = true;
1149     for (uint32_t i = 0; i < component_type_ids.size(); i++) {
1150       if (type_or_value.GetSingleWordInOperand(i) != component_type_ids[i]) {
1151         all_components_match = false;
1152         break;
1153       }
1154     }
1155     if (all_components_match) {
1156       return type_or_value.result_id();
1157     }
1158   }
1159   return 0;
1160 }
1161 
MaybeGetVoidType(opt::IRContext * ir_context)1162 uint32_t MaybeGetVoidType(opt::IRContext* ir_context) {
1163   opt::analysis::Void type;
1164   return ir_context->get_type_mgr()->GetId(&type);
1165 }
1166 
MaybeGetZeroConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,uint32_t scalar_or_composite_type_id,bool is_irrelevant)1167 uint32_t MaybeGetZeroConstant(
1168     opt::IRContext* ir_context,
1169     const TransformationContext& transformation_context,
1170     uint32_t scalar_or_composite_type_id, bool is_irrelevant) {
1171   const auto* type_inst =
1172       ir_context->get_def_use_mgr()->GetDef(scalar_or_composite_type_id);
1173   assert(type_inst && "|scalar_or_composite_type_id| is invalid");
1174 
1175   switch (type_inst->opcode()) {
1176     case SpvOpTypeBool:
1177       return MaybeGetBoolConstant(ir_context, transformation_context, false,
1178                                   is_irrelevant);
1179     case SpvOpTypeFloat:
1180     case SpvOpTypeInt: {
1181       const auto width = type_inst->GetSingleWordInOperand(0);
1182       std::vector<uint32_t> words = {0};
1183       if (width > 32) {
1184         words.push_back(0);
1185       }
1186 
1187       return MaybeGetScalarConstant(ir_context, transformation_context, words,
1188                                     scalar_or_composite_type_id, is_irrelevant);
1189     }
1190     case SpvOpTypeStruct: {
1191       std::vector<uint32_t> component_ids;
1192       for (uint32_t i = 0; i < type_inst->NumInOperands(); ++i) {
1193         const auto component_type_id = type_inst->GetSingleWordInOperand(i);
1194 
1195         auto component_id =
1196             MaybeGetZeroConstant(ir_context, transformation_context,
1197                                  component_type_id, is_irrelevant);
1198 
1199         if (component_id == 0 && is_irrelevant) {
1200           // Irrelevant constants can use either relevant or irrelevant
1201           // constituents.
1202           component_id = MaybeGetZeroConstant(
1203               ir_context, transformation_context, component_type_id, false);
1204         }
1205 
1206         if (component_id == 0) {
1207           return 0;
1208         }
1209 
1210         component_ids.push_back(component_id);
1211       }
1212 
1213       return MaybeGetCompositeConstant(
1214           ir_context, transformation_context, component_ids,
1215           scalar_or_composite_type_id, is_irrelevant);
1216     }
1217     case SpvOpTypeMatrix:
1218     case SpvOpTypeVector: {
1219       const auto component_type_id = type_inst->GetSingleWordInOperand(0);
1220 
1221       auto component_id = MaybeGetZeroConstant(
1222           ir_context, transformation_context, component_type_id, is_irrelevant);
1223 
1224       if (component_id == 0 && is_irrelevant) {
1225         // Irrelevant constants can use either relevant or irrelevant
1226         // constituents.
1227         component_id = MaybeGetZeroConstant(ir_context, transformation_context,
1228                                             component_type_id, false);
1229       }
1230 
1231       if (component_id == 0) {
1232         return 0;
1233       }
1234 
1235       const auto component_count = type_inst->GetSingleWordInOperand(1);
1236       return MaybeGetCompositeConstant(
1237           ir_context, transformation_context,
1238           std::vector<uint32_t>(component_count, component_id),
1239           scalar_or_composite_type_id, is_irrelevant);
1240     }
1241     case SpvOpTypeArray: {
1242       const auto component_type_id = type_inst->GetSingleWordInOperand(0);
1243 
1244       auto component_id = MaybeGetZeroConstant(
1245           ir_context, transformation_context, component_type_id, is_irrelevant);
1246 
1247       if (component_id == 0 && is_irrelevant) {
1248         // Irrelevant constants can use either relevant or irrelevant
1249         // constituents.
1250         component_id = MaybeGetZeroConstant(ir_context, transformation_context,
1251                                             component_type_id, false);
1252       }
1253 
1254       if (component_id == 0) {
1255         return 0;
1256       }
1257 
1258       return MaybeGetCompositeConstant(
1259           ir_context, transformation_context,
1260           std::vector<uint32_t>(GetArraySize(*type_inst, ir_context),
1261                                 component_id),
1262           scalar_or_composite_type_id, is_irrelevant);
1263     }
1264     default:
1265       assert(false && "Type is not supported");
1266       return 0;
1267   }
1268 }
1269 
CanCreateConstant(opt::IRContext * ir_context,uint32_t type_id)1270 bool CanCreateConstant(opt::IRContext* ir_context, uint32_t type_id) {
1271   opt::Instruction* type_instr = ir_context->get_def_use_mgr()->GetDef(type_id);
1272   assert(type_instr != nullptr && "The type must exist.");
1273   assert(spvOpcodeGeneratesType(type_instr->opcode()) &&
1274          "A type-generating opcode was expected.");
1275   switch (type_instr->opcode()) {
1276     case SpvOpTypeBool:
1277     case SpvOpTypeInt:
1278     case SpvOpTypeFloat:
1279     case SpvOpTypeMatrix:
1280     case SpvOpTypeVector:
1281       return true;
1282     case SpvOpTypeArray:
1283       return CanCreateConstant(ir_context,
1284                                type_instr->GetSingleWordInOperand(0));
1285     case SpvOpTypeStruct:
1286       if (HasBlockOrBufferBlockDecoration(ir_context, type_id)) {
1287         return false;
1288       }
1289       for (uint32_t index = 0; index < type_instr->NumInOperands(); index++) {
1290         if (!CanCreateConstant(ir_context,
1291                                type_instr->GetSingleWordInOperand(index))) {
1292           return false;
1293         }
1294       }
1295       return true;
1296     default:
1297       return false;
1298   }
1299 }
1300 
MaybeGetScalarConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t scalar_type_id,bool is_irrelevant)1301 uint32_t MaybeGetScalarConstant(
1302     opt::IRContext* ir_context,
1303     const TransformationContext& transformation_context,
1304     const std::vector<uint32_t>& words, uint32_t scalar_type_id,
1305     bool is_irrelevant) {
1306   const auto* type = ir_context->get_type_mgr()->GetType(scalar_type_id);
1307   assert(type && "|scalar_type_id| is invalid");
1308 
1309   if (const auto* int_type = type->AsInteger()) {
1310     return MaybeGetIntegerConstant(ir_context, transformation_context, words,
1311                                    int_type->width(), int_type->IsSigned(),
1312                                    is_irrelevant);
1313   } else if (const auto* float_type = type->AsFloat()) {
1314     return MaybeGetFloatConstant(ir_context, transformation_context, words,
1315                                  float_type->width(), is_irrelevant);
1316   } else {
1317     assert(type->AsBool() && words.size() == 1 &&
1318            "|scalar_type_id| doesn't represent a scalar type");
1319     return MaybeGetBoolConstant(ir_context, transformation_context, words[0],
1320                                 is_irrelevant);
1321   }
1322 }
1323 
MaybeGetCompositeConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & component_ids,uint32_t composite_type_id,bool is_irrelevant)1324 uint32_t MaybeGetCompositeConstant(
1325     opt::IRContext* ir_context,
1326     const TransformationContext& transformation_context,
1327     const std::vector<uint32_t>& component_ids, uint32_t composite_type_id,
1328     bool is_irrelevant) {
1329   const auto* type = ir_context->get_type_mgr()->GetType(composite_type_id);
1330   (void)type;  // Make compilers happy in release mode.
1331   assert(IsCompositeType(type) && "|composite_type_id| is invalid");
1332 
1333   for (const auto& inst : ir_context->types_values()) {
1334     if (inst.opcode() == SpvOpConstantComposite &&
1335         inst.type_id() == composite_type_id &&
1336         transformation_context.GetFactManager()->IdIsIrrelevant(
1337             inst.result_id()) == is_irrelevant &&
1338         inst.NumInOperands() == component_ids.size()) {
1339       bool is_match = true;
1340 
1341       for (uint32_t i = 0; i < inst.NumInOperands(); ++i) {
1342         if (inst.GetSingleWordInOperand(i) != component_ids[i]) {
1343           is_match = false;
1344           break;
1345         }
1346       }
1347 
1348       if (is_match) {
1349         return inst.result_id();
1350       }
1351     }
1352   }
1353 
1354   return 0;
1355 }
1356 
MaybeGetIntegerConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t width,bool is_signed,bool is_irrelevant)1357 uint32_t MaybeGetIntegerConstant(
1358     opt::IRContext* ir_context,
1359     const TransformationContext& transformation_context,
1360     const std::vector<uint32_t>& words, uint32_t width, bool is_signed,
1361     bool is_irrelevant) {
1362   if (auto type_id = MaybeGetIntegerType(ir_context, width, is_signed)) {
1363     return MaybeGetOpConstant(ir_context, transformation_context, words,
1364                               type_id, is_irrelevant);
1365   }
1366 
1367   return 0;
1368 }
1369 
MaybeGetIntegerConstantFromValueAndType(opt::IRContext * ir_context,uint32_t value,uint32_t int_type_id)1370 uint32_t MaybeGetIntegerConstantFromValueAndType(opt::IRContext* ir_context,
1371                                                  uint32_t value,
1372                                                  uint32_t int_type_id) {
1373   auto int_type_inst = ir_context->get_def_use_mgr()->GetDef(int_type_id);
1374 
1375   assert(int_type_inst && "The given type id must exist.");
1376 
1377   auto int_type = ir_context->get_type_mgr()
1378                       ->GetType(int_type_inst->result_id())
1379                       ->AsInteger();
1380 
1381   assert(int_type && int_type->width() == 32 &&
1382          "The given type id must correspond to an 32-bit integer type.");
1383 
1384   opt::analysis::IntConstant constant(int_type, {value});
1385 
1386   // Check that the constant exists in the module.
1387   if (!ir_context->get_constant_mgr()->FindConstant(&constant)) {
1388     return 0;
1389   }
1390 
1391   return ir_context->get_constant_mgr()
1392       ->GetDefiningInstruction(&constant)
1393       ->result_id();
1394 }
1395 
MaybeGetFloatConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t width,bool is_irrelevant)1396 uint32_t MaybeGetFloatConstant(
1397     opt::IRContext* ir_context,
1398     const TransformationContext& transformation_context,
1399     const std::vector<uint32_t>& words, uint32_t width, bool is_irrelevant) {
1400   if (auto type_id = MaybeGetFloatType(ir_context, width)) {
1401     return MaybeGetOpConstant(ir_context, transformation_context, words,
1402                               type_id, is_irrelevant);
1403   }
1404 
1405   return 0;
1406 }
1407 
MaybeGetBoolConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,bool value,bool is_irrelevant)1408 uint32_t MaybeGetBoolConstant(
1409     opt::IRContext* ir_context,
1410     const TransformationContext& transformation_context, bool value,
1411     bool is_irrelevant) {
1412   if (auto type_id = MaybeGetBoolType(ir_context)) {
1413     for (const auto& inst : ir_context->types_values()) {
1414       if (inst.opcode() == (value ? SpvOpConstantTrue : SpvOpConstantFalse) &&
1415           inst.type_id() == type_id &&
1416           transformation_context.GetFactManager()->IdIsIrrelevant(
1417               inst.result_id()) == is_irrelevant) {
1418         return inst.result_id();
1419       }
1420     }
1421   }
1422 
1423   return 0;
1424 }
1425 
// Converts the low |width| bits of |value| into a little-endian sequence of
// 32-bit words. The low |width| bits are first extended to 64 bits:
// sign-extended if |is_signed| is true, zero-extended otherwise. The result
// holds one word (the low 32 bits) when |width| <= 32, and two words (low,
// then high) when |width| > 32. |width| must not exceed 64.
//
// A |width| of 0 selects no bits and therefore yields a single zero word;
// shifting by the full 64-bit width to achieve this would be undefined
// behaviour, so masks are used instead of shift pairs.
std::vector<uint32_t> IntToWords(uint64_t value, uint32_t width,
                                 bool is_signed) {
  assert(width <= 64 && "The bit width should not be more than 64 bits");

  if (width == 0) {
    // No bits are selected.
    value = 0;
  } else if (width < 64) {
    // |low_mask| covers the |width| least significant bits; |sign_bit| is the
    // most significant of them. Both shifts are by strictly less than 64.
    const uint64_t low_mask = (uint64_t{1} << width) - 1;
    const uint64_t sign_bit = uint64_t{1} << (width - 1);
    if (is_signed && (value & sign_bit) != 0) {
      // Negative value: replicate the sign bit into all higher bits.
      value |= ~low_mask;
    } else {
      // Non-negative or unsigned value: clear all higher bits.
      value &= low_mask;
    }
  }
  // When |width| is exactly 64 every bit is already significant.

  std::vector<uint32_t> result;
  result.push_back(static_cast<uint32_t>(value));
  if (width > 32) {
    result.push_back(static_cast<uint32_t>(value >> 32));
  }
  return result;
}
1449 
TypesAreEqualUpToSign(opt::IRContext * ir_context,uint32_t type1_id,uint32_t type2_id)1450 bool TypesAreEqualUpToSign(opt::IRContext* ir_context, uint32_t type1_id,
1451                            uint32_t type2_id) {
1452   if (type1_id == type2_id) {
1453     return true;
1454   }
1455 
1456   auto type1 = ir_context->get_type_mgr()->GetType(type1_id);
1457   auto type2 = ir_context->get_type_mgr()->GetType(type2_id);
1458 
1459   // Integer scalar types must have the same width
1460   if (type1->AsInteger() && type2->AsInteger()) {
1461     return type1->AsInteger()->width() == type2->AsInteger()->width();
1462   }
1463 
1464   // Integer vector types must have the same number of components and their
1465   // component types must be integers with the same width.
1466   if (type1->AsVector() && type2->AsVector()) {
1467     auto component_type1 = type1->AsVector()->element_type()->AsInteger();
1468     auto component_type2 = type2->AsVector()->element_type()->AsInteger();
1469 
1470     // Only check the component count and width if they are integer.
1471     if (component_type1 && component_type2) {
1472       return type1->AsVector()->element_count() ==
1473                  type2->AsVector()->element_count() &&
1474              component_type1->width() == component_type2->width();
1475     }
1476   }
1477 
1478   // In all other cases, the types cannot be considered equal.
1479   return false;
1480 }
1481 
RepeatedUInt32PairToMap(const google::protobuf::RepeatedPtrField<protobufs::UInt32Pair> & data)1482 std::map<uint32_t, uint32_t> RepeatedUInt32PairToMap(
1483     const google::protobuf::RepeatedPtrField<protobufs::UInt32Pair>& data) {
1484   std::map<uint32_t, uint32_t> result;
1485 
1486   for (const auto& entry : data) {
1487     result[entry.first()] = entry.second();
1488   }
1489 
1490   return result;
1491 }
1492 
1493 google::protobuf::RepeatedPtrField<protobufs::UInt32Pair>
MapToRepeatedUInt32Pair(const std::map<uint32_t,uint32_t> & data)1494 MapToRepeatedUInt32Pair(const std::map<uint32_t, uint32_t>& data) {
1495   google::protobuf::RepeatedPtrField<protobufs::UInt32Pair> result;
1496 
1497   for (const auto& entry : data) {
1498     protobufs::UInt32Pair pair;
1499     pair.set_first(entry.first);
1500     pair.set_second(entry.second);
1501     *result.Add() = std::move(pair);
1502   }
1503 
1504   return result;
1505 }
1506 
GetLastInsertBeforeInstruction(opt::IRContext * ir_context,uint32_t block_id,SpvOp opcode)1507 opt::Instruction* GetLastInsertBeforeInstruction(opt::IRContext* ir_context,
1508                                                  uint32_t block_id,
1509                                                  SpvOp opcode) {
1510   // CFG::block uses std::map::at which throws an exception when |block_id| is
1511   // invalid. The error message is unhelpful, though. Thus, we test that
1512   // |block_id| is valid here.
1513   const auto* label_inst = ir_context->get_def_use_mgr()->GetDef(block_id);
1514   (void)label_inst;  // Make compilers happy in release mode.
1515   assert(label_inst && label_inst->opcode() == SpvOpLabel &&
1516          "|block_id| is invalid");
1517 
1518   auto* block = ir_context->cfg()->block(block_id);
1519   auto it = block->rbegin();
1520   assert(it != block->rend() && "Basic block can't be empty");
1521 
1522   if (block->GetMergeInst()) {
1523     ++it;
1524     assert(it != block->rend() &&
1525            "|block| must have at least two instructions:"
1526            "terminator and a merge instruction");
1527   }
1528 
1529   return CanInsertOpcodeBeforeInstruction(opcode, &*it) ? &*it : nullptr;
1530 }
1531 
IdUseCanBeReplaced(opt::IRContext * ir_context,const TransformationContext & transformation_context,opt::Instruction * use_instruction,uint32_t use_in_operand_index)1532 bool IdUseCanBeReplaced(opt::IRContext* ir_context,
1533                         const TransformationContext& transformation_context,
1534                         opt::Instruction* use_instruction,
1535                         uint32_t use_in_operand_index) {
1536   if (spvOpcodeIsAccessChain(use_instruction->opcode()) &&
1537       use_in_operand_index > 0) {
1538     // A replacement for an irrelevant index in OpAccessChain must be clamped
1539     // first.
1540     if (transformation_context.GetFactManager()->IdIsIrrelevant(
1541             use_instruction->GetSingleWordInOperand(use_in_operand_index))) {
1542       return false;
1543     }
1544 
1545     // This is an access chain index.  If the (sub-)object being accessed by the
1546     // given index has struct type then we cannot replace the use, as it needs
1547     // to be an OpConstant.
1548 
1549     // Get the top-level composite type that is being accessed.
1550     auto object_being_accessed = ir_context->get_def_use_mgr()->GetDef(
1551         use_instruction->GetSingleWordInOperand(0));
1552     auto pointer_type =
1553         ir_context->get_type_mgr()->GetType(object_being_accessed->type_id());
1554     assert(pointer_type->AsPointer());
1555     auto composite_type_being_accessed =
1556         pointer_type->AsPointer()->pointee_type();
1557 
1558     // Now walk the access chain, tracking the type of each sub-object of the
1559     // composite that is traversed, until the index of interest is reached.
1560     for (uint32_t index_in_operand = 1; index_in_operand < use_in_operand_index;
1561          index_in_operand++) {
1562       // For vectors, matrices and arrays, getting the type of the sub-object is
1563       // trivial. For the struct case, the sub-object type is field-sensitive,
1564       // and depends on the constant index that is used.
1565       if (composite_type_being_accessed->AsVector()) {
1566         composite_type_being_accessed =
1567             composite_type_being_accessed->AsVector()->element_type();
1568       } else if (composite_type_being_accessed->AsMatrix()) {
1569         composite_type_being_accessed =
1570             composite_type_being_accessed->AsMatrix()->element_type();
1571       } else if (composite_type_being_accessed->AsArray()) {
1572         composite_type_being_accessed =
1573             composite_type_being_accessed->AsArray()->element_type();
1574       } else if (composite_type_being_accessed->AsRuntimeArray()) {
1575         composite_type_being_accessed =
1576             composite_type_being_accessed->AsRuntimeArray()->element_type();
1577       } else {
1578         assert(composite_type_being_accessed->AsStruct());
1579         auto constant_index_instruction = ir_context->get_def_use_mgr()->GetDef(
1580             use_instruction->GetSingleWordInOperand(index_in_operand));
1581         assert(constant_index_instruction->opcode() == SpvOpConstant);
1582         uint32_t member_index =
1583             constant_index_instruction->GetSingleWordInOperand(0);
1584         composite_type_being_accessed =
1585             composite_type_being_accessed->AsStruct()
1586                 ->element_types()[member_index];
1587       }
1588     }
1589 
1590     // We have found the composite type being accessed by the index we are
1591     // considering replacing. If it is a struct, then we cannot do the
1592     // replacement as struct indices must be constants.
1593     if (composite_type_being_accessed->AsStruct()) {
1594       return false;
1595     }
1596   }
1597 
1598   if (use_instruction->opcode() == SpvOpFunctionCall &&
1599       use_in_operand_index > 0) {
1600     // This is a function call argument.  It is not allowed to have pointer
1601     // type.
1602 
1603     // Get the definition of the function being called.
1604     auto function = ir_context->get_def_use_mgr()->GetDef(
1605         use_instruction->GetSingleWordInOperand(0));
1606     // From the function definition, get the function type.
1607     auto function_type = ir_context->get_def_use_mgr()->GetDef(
1608         function->GetSingleWordInOperand(1));
1609     // OpTypeFunction's 0-th input operand is the function return type, and the
1610     // function argument types follow. Because the arguments to OpFunctionCall
1611     // start from input operand 1, we can use |use_in_operand_index| to get the
1612     // type associated with this function argument.
1613     auto parameter_type = ir_context->get_type_mgr()->GetType(
1614         function_type->GetSingleWordInOperand(use_in_operand_index));
1615     if (parameter_type->AsPointer()) {
1616       return false;
1617     }
1618   }
1619 
1620   if (use_instruction->opcode() == SpvOpImageTexelPointer &&
1621       use_in_operand_index == 2) {
1622     // The OpImageTexelPointer instruction has a Sample parameter that in some
1623     // situations must be an id for the value 0.  To guard against disrupting
1624     // that requirement, we do not replace this argument to that instruction.
1625     return false;
1626   }
1627 
1628   return true;
1629 }
1630 
MembersHaveBuiltInDecoration(opt::IRContext * ir_context,uint32_t struct_type_id)1631 bool MembersHaveBuiltInDecoration(opt::IRContext* ir_context,
1632                                   uint32_t struct_type_id) {
1633   const auto* type_inst = ir_context->get_def_use_mgr()->GetDef(struct_type_id);
1634   assert(type_inst && type_inst->opcode() == SpvOpTypeStruct &&
1635          "|struct_type_id| is not a result id of an OpTypeStruct");
1636 
1637   uint32_t builtin_count = 0;
1638   ir_context->get_def_use_mgr()->ForEachUser(
1639       type_inst,
1640       [struct_type_id, &builtin_count](const opt::Instruction* user) {
1641         if (user->opcode() == SpvOpMemberDecorate &&
1642             user->GetSingleWordInOperand(0) == struct_type_id &&
1643             static_cast<SpvDecoration>(user->GetSingleWordInOperand(2)) ==
1644                 SpvDecorationBuiltIn) {
1645           ++builtin_count;
1646         }
1647       });
1648 
1649   assert((builtin_count == 0 || builtin_count == type_inst->NumInOperands()) &&
1650          "The module is invalid: either none or all of the members of "
1651          "|struct_type_id| may be builtin");
1652 
1653   return builtin_count != 0;
1654 }
1655 
HasBlockOrBufferBlockDecoration(opt::IRContext * ir_context,uint32_t id)1656 bool HasBlockOrBufferBlockDecoration(opt::IRContext* ir_context, uint32_t id) {
1657   for (auto decoration : {SpvDecorationBlock, SpvDecorationBufferBlock}) {
1658     if (!ir_context->get_decoration_mgr()->WhileEachDecoration(
1659             id, decoration, [](const opt::Instruction & /*unused*/) -> bool {
1660               return false;
1661             })) {
1662       return true;
1663     }
1664   }
1665   return false;
1666 }
1667 
SplittingBeforeInstructionSeparatesOpSampledImageDefinitionFromUse(opt::BasicBlock * block_to_split,opt::Instruction * split_before)1668 bool SplittingBeforeInstructionSeparatesOpSampledImageDefinitionFromUse(
1669     opt::BasicBlock* block_to_split, opt::Instruction* split_before) {
1670   std::set<uint32_t> sampled_image_result_ids;
1671   bool before_split = true;
1672 
1673   // Check all the instructions in the block to split.
1674   for (auto& instruction : *block_to_split) {
1675     if (&instruction == &*split_before) {
1676       before_split = false;
1677     }
1678     if (before_split) {
1679       // If the instruction comes before the split and its opcode is
1680       // OpSampledImage, record its result id.
1681       if (instruction.opcode() == SpvOpSampledImage) {
1682         sampled_image_result_ids.insert(instruction.result_id());
1683       }
1684     } else {
1685       // If the instruction comes after the split, check if ids
1686       // corresponding to OpSampledImage instructions defined before the split
1687       // are used, and return true if they are.
1688       if (!instruction.WhileEachInId(
1689               [&sampled_image_result_ids](uint32_t* id) -> bool {
1690                 return !sampled_image_result_ids.count(*id);
1691               })) {
1692         return true;
1693       }
1694     }
1695   }
1696 
1697   // No usage that would be separated from the definition has been found.
1698   return false;
1699 }
1700 
InstructionHasNoSideEffects(const opt::Instruction & instruction)1701 bool InstructionHasNoSideEffects(const opt::Instruction& instruction) {
1702   switch (instruction.opcode()) {
1703     case SpvOpUndef:
1704     case SpvOpAccessChain:
1705     case SpvOpInBoundsAccessChain:
1706     case SpvOpArrayLength:
1707     case SpvOpVectorExtractDynamic:
1708     case SpvOpVectorInsertDynamic:
1709     case SpvOpVectorShuffle:
1710     case SpvOpCompositeConstruct:
1711     case SpvOpCompositeExtract:
1712     case SpvOpCompositeInsert:
1713     case SpvOpCopyObject:
1714     case SpvOpTranspose:
1715     case SpvOpConvertFToU:
1716     case SpvOpConvertFToS:
1717     case SpvOpConvertSToF:
1718     case SpvOpConvertUToF:
1719     case SpvOpUConvert:
1720     case SpvOpSConvert:
1721     case SpvOpFConvert:
1722     case SpvOpQuantizeToF16:
1723     case SpvOpSatConvertSToU:
1724     case SpvOpSatConvertUToS:
1725     case SpvOpBitcast:
1726     case SpvOpSNegate:
1727     case SpvOpFNegate:
1728     case SpvOpIAdd:
1729     case SpvOpFAdd:
1730     case SpvOpISub:
1731     case SpvOpFSub:
1732     case SpvOpIMul:
1733     case SpvOpFMul:
1734     case SpvOpUDiv:
1735     case SpvOpSDiv:
1736     case SpvOpFDiv:
1737     case SpvOpUMod:
1738     case SpvOpSRem:
1739     case SpvOpSMod:
1740     case SpvOpFRem:
1741     case SpvOpFMod:
1742     case SpvOpVectorTimesScalar:
1743     case SpvOpMatrixTimesScalar:
1744     case SpvOpVectorTimesMatrix:
1745     case SpvOpMatrixTimesVector:
1746     case SpvOpMatrixTimesMatrix:
1747     case SpvOpOuterProduct:
1748     case SpvOpDot:
1749     case SpvOpIAddCarry:
1750     case SpvOpISubBorrow:
1751     case SpvOpUMulExtended:
1752     case SpvOpSMulExtended:
1753     case SpvOpAny:
1754     case SpvOpAll:
1755     case SpvOpIsNan:
1756     case SpvOpIsInf:
1757     case SpvOpIsFinite:
1758     case SpvOpIsNormal:
1759     case SpvOpSignBitSet:
1760     case SpvOpLessOrGreater:
1761     case SpvOpOrdered:
1762     case SpvOpUnordered:
1763     case SpvOpLogicalEqual:
1764     case SpvOpLogicalNotEqual:
1765     case SpvOpLogicalOr:
1766     case SpvOpLogicalAnd:
1767     case SpvOpLogicalNot:
1768     case SpvOpSelect:
1769     case SpvOpIEqual:
1770     case SpvOpINotEqual:
1771     case SpvOpUGreaterThan:
1772     case SpvOpSGreaterThan:
1773     case SpvOpUGreaterThanEqual:
1774     case SpvOpSGreaterThanEqual:
1775     case SpvOpULessThan:
1776     case SpvOpSLessThan:
1777     case SpvOpULessThanEqual:
1778     case SpvOpSLessThanEqual:
1779     case SpvOpFOrdEqual:
1780     case SpvOpFUnordEqual:
1781     case SpvOpFOrdNotEqual:
1782     case SpvOpFUnordNotEqual:
1783     case SpvOpFOrdLessThan:
1784     case SpvOpFUnordLessThan:
1785     case SpvOpFOrdGreaterThan:
1786     case SpvOpFUnordGreaterThan:
1787     case SpvOpFOrdLessThanEqual:
1788     case SpvOpFUnordLessThanEqual:
1789     case SpvOpFOrdGreaterThanEqual:
1790     case SpvOpFUnordGreaterThanEqual:
1791     case SpvOpShiftRightLogical:
1792     case SpvOpShiftRightArithmetic:
1793     case SpvOpShiftLeftLogical:
1794     case SpvOpBitwiseOr:
1795     case SpvOpBitwiseXor:
1796     case SpvOpBitwiseAnd:
1797     case SpvOpNot:
1798     case SpvOpBitFieldInsert:
1799     case SpvOpBitFieldSExtract:
1800     case SpvOpBitFieldUExtract:
1801     case SpvOpBitReverse:
1802     case SpvOpBitCount:
1803     case SpvOpCopyLogical:
1804     case SpvOpPhi:
1805     case SpvOpPtrEqual:
1806     case SpvOpPtrNotEqual:
1807       return true;
1808     default:
1809       return false;
1810   }
1811 }
1812 
GetReachableReturnBlocks(opt::IRContext * ir_context,uint32_t function_id)1813 std::set<uint32_t> GetReachableReturnBlocks(opt::IRContext* ir_context,
1814                                             uint32_t function_id) {
1815   auto function = ir_context->GetFunction(function_id);
1816   assert(function && "The function |function_id| must exist.");
1817 
1818   std::set<uint32_t> result;
1819 
1820   ir_context->cfg()->ForEachBlockInPostOrder(function->entry().get(),
1821                                              [&result](opt::BasicBlock* block) {
1822                                                if (block->IsReturn()) {
1823                                                  result.emplace(block->id());
1824                                                }
1825                                              });
1826 
1827   return result;
1828 }
1829 
NewTerminatorPreservesDominationRules(opt::IRContext * ir_context,uint32_t block_id,opt::Instruction new_terminator)1830 bool NewTerminatorPreservesDominationRules(opt::IRContext* ir_context,
1831                                            uint32_t block_id,
1832                                            opt::Instruction new_terminator) {
1833   auto* mutated_block = MaybeFindBlock(ir_context, block_id);
1834   assert(mutated_block && "|block_id| is invalid");
1835 
1836   ChangeTerminatorRAII change_terminator_raii(mutated_block,
1837                                               std::move(new_terminator));
1838   opt::DominatorAnalysis dominator_analysis;
1839   dominator_analysis.InitializeTree(*ir_context->cfg(),
1840                                     mutated_block->GetParent());
1841 
1842   // Check that each dominator appears before each dominated block.
1843   std::unordered_map<uint32_t, size_t> positions;
1844   for (const auto& block : *mutated_block->GetParent()) {
1845     positions[block.id()] = positions.size();
1846   }
1847 
1848   std::queue<uint32_t> q({mutated_block->GetParent()->begin()->id()});
1849   std::unordered_set<uint32_t> visited;
1850   while (!q.empty()) {
1851     auto block = q.front();
1852     q.pop();
1853     visited.insert(block);
1854 
1855     auto success = ir_context->cfg()->block(block)->WhileEachSuccessorLabel(
1856         [&positions, &visited, &dominator_analysis, block, &q](uint32_t id) {
1857           if (id == block) {
1858             // Handle the case when loop header and continue target are the same
1859             // block.
1860             return true;
1861           }
1862 
1863           if (dominator_analysis.Dominates(block, id) &&
1864               positions[block] > positions[id]) {
1865             // |block| dominates |id| but appears after |id| - violates
1866             // domination rules.
1867             return false;
1868           }
1869 
1870           if (!visited.count(id)) {
1871             q.push(id);
1872           }
1873 
1874           return true;
1875         });
1876 
1877     if (!success) {
1878       return false;
1879     }
1880   }
1881 
1882   // For each instruction in the |block->GetParent()| function check whether
1883   // all its dependencies satisfy domination rules (i.e. all id operands
1884   // dominate that instruction).
1885   for (const auto& block : *mutated_block->GetParent()) {
1886     if (!dominator_analysis.IsReachable(&block)) {
1887       // If some block is not reachable then we don't need to worry about the
1888       // preservation of domination rules for its instructions.
1889       continue;
1890     }
1891 
1892     for (const auto& inst : block) {
1893       for (uint32_t i = 0; i < inst.NumInOperands();
1894            i += inst.opcode() == SpvOpPhi ? 2 : 1) {
1895         const auto& operand = inst.GetInOperand(i);
1896         if (!spvIsInIdType(operand.type)) {
1897           continue;
1898         }
1899 
1900         if (MaybeFindBlock(ir_context, operand.words[0])) {
1901           // Ignore operands that refer to OpLabel instructions.
1902           continue;
1903         }
1904 
1905         const auto* dependency_block =
1906             ir_context->get_instr_block(operand.words[0]);
1907         if (!dependency_block) {
1908           // A global instruction always dominates all instructions in any
1909           // function.
1910           continue;
1911         }
1912 
1913         auto domination_target_id = inst.opcode() == SpvOpPhi
1914                                         ? inst.GetSingleWordInOperand(i + 1)
1915                                         : block.id();
1916 
1917         if (!dominator_analysis.Dominates(dependency_block->id(),
1918                                           domination_target_id)) {
1919           return false;
1920         }
1921       }
1922     }
1923   }
1924 
1925   return true;
1926 }
1927 
1928 }  // namespace fuzzerutil
1929 }  // namespace fuzz
1930 }  // namespace spvtools
1931