• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/fuzz/fuzzer_util.h"
16 
17 #include <algorithm>
18 #include <unordered_set>
19 
20 #include "source/opt/build_module.h"
21 
22 namespace spvtools {
23 namespace fuzz {
24 
25 namespace fuzzerutil {
26 namespace {
27 
28 // A utility class that uses RAII to change and restore the terminator
29 // instruction of the |block|.
30 class ChangeTerminatorRAII {
31  public:
ChangeTerminatorRAII(opt::BasicBlock * block,opt::Instruction new_terminator)32   explicit ChangeTerminatorRAII(opt::BasicBlock* block,
33                                 opt::Instruction new_terminator)
34       : block_(block), old_terminator_(std::move(*block->terminator())) {
35     *block_->terminator() = std::move(new_terminator);
36   }
37 
~ChangeTerminatorRAII()38   ~ChangeTerminatorRAII() {
39     *block_->terminator() = std::move(old_terminator_);
40   }
41 
42  private:
43   opt::BasicBlock* block_;
44   opt::Instruction old_terminator_;
45 };
46 
MaybeGetOpConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t type_id,bool is_irrelevant)47 uint32_t MaybeGetOpConstant(opt::IRContext* ir_context,
48                             const TransformationContext& transformation_context,
49                             const std::vector<uint32_t>& words,
50                             uint32_t type_id, bool is_irrelevant) {
51   for (const auto& inst : ir_context->types_values()) {
52     if (inst.opcode() == SpvOpConstant && inst.type_id() == type_id &&
53         inst.GetInOperand(0).words == words &&
54         transformation_context.GetFactManager()->IdIsIrrelevant(
55             inst.result_id()) == is_irrelevant) {
56       return inst.result_id();
57     }
58   }
59 
60   return 0;
61 }
62 
63 }  // namespace
64 
65 const spvtools::MessageConsumer kSilentMessageConsumer =
66     [](spv_message_level_t, const char*, const spv_position_t&,
__anonbc26db0a0202(spv_message_level_t, const char*, const spv_position_t&, const char*) 67        const char*) -> void {};
68 
BuildIRContext(spv_target_env target_env,const spvtools::MessageConsumer & message_consumer,const std::vector<uint32_t> & binary_in,spv_validator_options validator_options,std::unique_ptr<spvtools::opt::IRContext> * ir_context)69 bool BuildIRContext(spv_target_env target_env,
70                     const spvtools::MessageConsumer& message_consumer,
71                     const std::vector<uint32_t>& binary_in,
72                     spv_validator_options validator_options,
73                     std::unique_ptr<spvtools::opt::IRContext>* ir_context) {
74   SpirvTools tools(target_env);
75   tools.SetMessageConsumer(message_consumer);
76   if (!tools.IsValid()) {
77     message_consumer(SPV_MSG_ERROR, nullptr, {},
78                      "Failed to create SPIRV-Tools interface; stopping.");
79     return false;
80   }
81 
82   // Initial binary should be valid.
83   if (!tools.Validate(binary_in.data(), binary_in.size(), validator_options)) {
84     message_consumer(SPV_MSG_ERROR, nullptr, {},
85                      "Initial binary is invalid; stopping.");
86     return false;
87   }
88 
89   // Build the module from the input binary.
90   auto result = BuildModule(target_env, message_consumer, binary_in.data(),
91                             binary_in.size());
92   assert(result && "IRContext must be valid");
93   *ir_context = std::move(result);
94   return true;
95 }
96 
IsFreshId(opt::IRContext * context,uint32_t id)97 bool IsFreshId(opt::IRContext* context, uint32_t id) {
98   return !context->get_def_use_mgr()->GetDef(id);
99 }
100 
UpdateModuleIdBound(opt::IRContext * context,uint32_t id)101 void UpdateModuleIdBound(opt::IRContext* context, uint32_t id) {
102   // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/2541) consider the
103   //  case where the maximum id bound is reached.
104   context->module()->SetIdBound(
105       std::max(context->module()->id_bound(), id + 1));
106 }
107 
MaybeFindBlock(opt::IRContext * context,uint32_t maybe_block_id)108 opt::BasicBlock* MaybeFindBlock(opt::IRContext* context,
109                                 uint32_t maybe_block_id) {
110   auto inst = context->get_def_use_mgr()->GetDef(maybe_block_id);
111   if (inst == nullptr) {
112     // No instruction defining this id was found.
113     return nullptr;
114   }
115   if (inst->opcode() != SpvOpLabel) {
116     // The instruction defining the id is not a label, so it cannot be a block
117     // id.
118     return nullptr;
119   }
120   return context->cfg()->block(maybe_block_id);
121 }
122 
PhiIdsOkForNewEdge(opt::IRContext * context,opt::BasicBlock * bb_from,opt::BasicBlock * bb_to,const google::protobuf::RepeatedField<google::protobuf::uint32> & phi_ids)123 bool PhiIdsOkForNewEdge(
124     opt::IRContext* context, opt::BasicBlock* bb_from, opt::BasicBlock* bb_to,
125     const google::protobuf::RepeatedField<google::protobuf::uint32>& phi_ids) {
126   if (bb_from->IsSuccessor(bb_to)) {
127     // There is already an edge from |from_block| to |to_block|, so there is
128     // no need to extend OpPhi instructions.  Do not allow phi ids to be
129     // present. This might turn out to be too strict; perhaps it would be OK
130     // just to ignore the ids in this case.
131     return phi_ids.empty();
132   }
133   // The edge would add a previously non-existent edge from |from_block| to
134   // |to_block|, so we go through the given phi ids and check that they exactly
135   // match the OpPhi instructions in |to_block|.
136   uint32_t phi_index = 0;
137   // An explicit loop, rather than applying a lambda to each OpPhi in |bb_to|,
138   // makes sense here because we need to increment |phi_index| for each OpPhi
139   // instruction.
140   for (auto& inst : *bb_to) {
141     if (inst.opcode() != SpvOpPhi) {
142       // The OpPhi instructions all occur at the start of the block; if we find
143       // a non-OpPhi then we have seen them all.
144       break;
145     }
146     if (phi_index == static_cast<uint32_t>(phi_ids.size())) {
147       // Not enough phi ids have been provided to account for the OpPhi
148       // instructions.
149       return false;
150     }
151     // Look for an instruction defining the next phi id.
152     opt::Instruction* phi_extension =
153         context->get_def_use_mgr()->GetDef(phi_ids[phi_index]);
154     if (!phi_extension) {
155       // The id given to extend this OpPhi does not exist.
156       return false;
157     }
158     if (phi_extension->type_id() != inst.type_id()) {
159       // The instruction given to extend this OpPhi either does not have a type
160       // or its type does not match that of the OpPhi.
161       return false;
162     }
163 
164     if (context->get_instr_block(phi_extension)) {
165       // The instruction defining the phi id has an associated block (i.e., it
166       // is not a global value).  Check whether its definition dominates the
167       // exit of |from_block|.
168       auto dominator_analysis =
169           context->GetDominatorAnalysis(bb_from->GetParent());
170       if (!dominator_analysis->Dominates(phi_extension,
171                                          bb_from->terminator())) {
172         // The given id is no good as its definition does not dominate the exit
173         // of |from_block|
174         return false;
175       }
176     }
177     phi_index++;
178   }
179   // We allow some of the ids provided for extending OpPhi instructions to be
180   // unused.  Their presence does no harm, and requiring a perfect match may
181   // make transformations less likely to cleanly apply.
182   return true;
183 }
184 
CreateUnreachableEdgeInstruction(opt::IRContext * ir_context,uint32_t bb_from_id,uint32_t bb_to_id,uint32_t bool_id)185 opt::Instruction CreateUnreachableEdgeInstruction(opt::IRContext* ir_context,
186                                                   uint32_t bb_from_id,
187                                                   uint32_t bb_to_id,
188                                                   uint32_t bool_id) {
189   const auto* bb_from = MaybeFindBlock(ir_context, bb_from_id);
190   assert(bb_from && "|bb_from_id| is invalid");
191   assert(MaybeFindBlock(ir_context, bb_to_id) && "|bb_to_id| is invalid");
192   assert(bb_from->terminator()->opcode() == SpvOpBranch &&
193          "Precondition on terminator of bb_from is not satisfied");
194 
195   // Get the id of the boolean constant to be used as the condition.
196   auto condition_inst = ir_context->get_def_use_mgr()->GetDef(bool_id);
197   assert(condition_inst &&
198          (condition_inst->opcode() == SpvOpConstantTrue ||
199           condition_inst->opcode() == SpvOpConstantFalse) &&
200          "|bool_id| is invalid");
201 
202   auto condition_value = condition_inst->opcode() == SpvOpConstantTrue;
203   auto successor_id = bb_from->terminator()->GetSingleWordInOperand(0);
204 
205   // Add the dead branch, by turning OpBranch into OpBranchConditional, and
206   // ordering the targets depending on whether the given boolean corresponds to
207   // true or false.
208   return opt::Instruction(
209       ir_context, SpvOpBranchConditional, 0, 0,
210       {{SPV_OPERAND_TYPE_ID, {bool_id}},
211        {SPV_OPERAND_TYPE_ID, {condition_value ? successor_id : bb_to_id}},
212        {SPV_OPERAND_TYPE_ID, {condition_value ? bb_to_id : successor_id}}});
213 }
214 
AddUnreachableEdgeAndUpdateOpPhis(opt::IRContext * context,opt::BasicBlock * bb_from,opt::BasicBlock * bb_to,uint32_t bool_id,const google::protobuf::RepeatedField<google::protobuf::uint32> & phi_ids)215 void AddUnreachableEdgeAndUpdateOpPhis(
216     opt::IRContext* context, opt::BasicBlock* bb_from, opt::BasicBlock* bb_to,
217     uint32_t bool_id,
218     const google::protobuf::RepeatedField<google::protobuf::uint32>& phi_ids) {
219   assert(PhiIdsOkForNewEdge(context, bb_from, bb_to, phi_ids) &&
220          "Precondition on phi_ids is not satisfied");
221 
222   const bool from_to_edge_already_exists = bb_from->IsSuccessor(bb_to);
223   *bb_from->terminator() = CreateUnreachableEdgeInstruction(
224       context, bb_from->id(), bb_to->id(), bool_id);
225 
226   // Update OpPhi instructions in the target block if this branch adds a
227   // previously non-existent edge from source to target.
228   if (!from_to_edge_already_exists) {
229     uint32_t phi_index = 0;
230     for (auto& inst : *bb_to) {
231       if (inst.opcode() != SpvOpPhi) {
232         break;
233       }
234       assert(phi_index < static_cast<uint32_t>(phi_ids.size()) &&
235              "There should be at least one phi id per OpPhi instruction.");
236       inst.AddOperand({SPV_OPERAND_TYPE_ID, {phi_ids[phi_index]}});
237       inst.AddOperand({SPV_OPERAND_TYPE_ID, {bb_from->id()}});
238       phi_index++;
239     }
240   }
241 }
242 
BlockIsBackEdge(opt::IRContext * context,uint32_t block_id,uint32_t loop_header_id)243 bool BlockIsBackEdge(opt::IRContext* context, uint32_t block_id,
244                      uint32_t loop_header_id) {
245   auto block = context->cfg()->block(block_id);
246   auto loop_header = context->cfg()->block(loop_header_id);
247 
248   // |block| and |loop_header| must be defined, |loop_header| must be in fact
249   // loop header and |block| must branch to it.
250   if (!(block && loop_header && loop_header->IsLoopHeader() &&
251         block->IsSuccessor(loop_header))) {
252     return false;
253   }
254 
255   // |block| must be reachable and be dominated by |loop_header|.
256   opt::DominatorAnalysis* dominator_analysis =
257       context->GetDominatorAnalysis(loop_header->GetParent());
258   return context->IsReachable(*block) &&
259          dominator_analysis->Dominates(loop_header, block);
260 }
261 
BlockIsInLoopContinueConstruct(opt::IRContext * context,uint32_t block_id,uint32_t maybe_loop_header_id)262 bool BlockIsInLoopContinueConstruct(opt::IRContext* context, uint32_t block_id,
263                                     uint32_t maybe_loop_header_id) {
264   // We deem a block to be part of a loop's continue construct if the loop's
265   // continue target dominates the block.
266   auto containing_construct_block = context->cfg()->block(maybe_loop_header_id);
267   if (containing_construct_block->IsLoopHeader()) {
268     auto continue_target = containing_construct_block->ContinueBlockId();
269     if (context->GetDominatorAnalysis(containing_construct_block->GetParent())
270             ->Dominates(continue_target, block_id)) {
271       return true;
272     }
273   }
274   return false;
275 }
276 
GetIteratorForInstruction(opt::BasicBlock * block,const opt::Instruction * inst)277 opt::BasicBlock::iterator GetIteratorForInstruction(
278     opt::BasicBlock* block, const opt::Instruction* inst) {
279   for (auto inst_it = block->begin(); inst_it != block->end(); ++inst_it) {
280     if (inst == &*inst_it) {
281       return inst_it;
282     }
283   }
284   return block->end();
285 }
286 
CanInsertOpcodeBeforeInstruction(SpvOp opcode,const opt::BasicBlock::iterator & instruction_in_block)287 bool CanInsertOpcodeBeforeInstruction(
288     SpvOp opcode, const opt::BasicBlock::iterator& instruction_in_block) {
289   if (instruction_in_block->PreviousNode() &&
290       (instruction_in_block->PreviousNode()->opcode() == SpvOpLoopMerge ||
291        instruction_in_block->PreviousNode()->opcode() == SpvOpSelectionMerge)) {
292     // We cannot insert directly after a merge instruction.
293     return false;
294   }
295   if (opcode != SpvOpVariable &&
296       instruction_in_block->opcode() == SpvOpVariable) {
297     // We cannot insert a non-OpVariable instruction directly before a
298     // variable; variables in a function must be contiguous in the entry block.
299     return false;
300   }
301   // We cannot insert a non-OpPhi instruction directly before an OpPhi, because
302   // OpPhi instructions need to be contiguous at the start of a block.
303   return opcode == SpvOpPhi || instruction_in_block->opcode() != SpvOpPhi;
304 }
305 
CanMakeSynonymOf(opt::IRContext * ir_context,const TransformationContext & transformation_context,opt::Instruction * inst)306 bool CanMakeSynonymOf(opt::IRContext* ir_context,
307                       const TransformationContext& transformation_context,
308                       opt::Instruction* inst) {
309   if (inst->opcode() == SpvOpSampledImage) {
310     // The SPIR-V data rules say that only very specific instructions may
311     // may consume the result id of an OpSampledImage, and this excludes the
312     // instructions that are used for making synonyms.
313     return false;
314   }
315   if (!inst->HasResultId()) {
316     // We can only make a synonym of an instruction that generates an id.
317     return false;
318   }
319   if (transformation_context.GetFactManager()->IdIsIrrelevant(
320           inst->result_id())) {
321     // An irrelevant id can't be a synonym of anything.
322     return false;
323   }
324   if (!inst->type_id()) {
325     // We can only make a synonym of an instruction that has a type.
326     return false;
327   }
328   auto type_inst = ir_context->get_def_use_mgr()->GetDef(inst->type_id());
329   if (type_inst->opcode() == SpvOpTypeVoid) {
330     // We only make synonyms of instructions that define objects, and an object
331     // cannot have void type.
332     return false;
333   }
334   if (type_inst->opcode() == SpvOpTypePointer) {
335     switch (inst->opcode()) {
336       case SpvOpConstantNull:
337       case SpvOpUndef:
338         // We disallow making synonyms of null or undefined pointers.  This is
339         // to provide the property that if the original shader exhibited no bad
340         // pointer accesses, the transformed shader will not either.
341         return false;
342       default:
343         break;
344     }
345   }
346 
347   // We do not make synonyms of objects that have decorations: if the synonym is
348   // not decorated analogously, using the original object vs. its synonymous
349   // form may not be equivalent.
350   return ir_context->get_decoration_mgr()
351       ->GetDecorationsFor(inst->result_id(), true)
352       .empty();
353 }
354 
IsCompositeType(const opt::analysis::Type * type)355 bool IsCompositeType(const opt::analysis::Type* type) {
356   return type && (type->AsArray() || type->AsMatrix() || type->AsStruct() ||
357                   type->AsVector());
358 }
359 
RepeatedFieldToVector(const google::protobuf::RepeatedField<uint32_t> & repeated_field)360 std::vector<uint32_t> RepeatedFieldToVector(
361     const google::protobuf::RepeatedField<uint32_t>& repeated_field) {
362   std::vector<uint32_t> result;
363   for (auto i : repeated_field) {
364     result.push_back(i);
365   }
366   return result;
367 }
368 
WalkOneCompositeTypeIndex(opt::IRContext * context,uint32_t base_object_type_id,uint32_t index)369 uint32_t WalkOneCompositeTypeIndex(opt::IRContext* context,
370                                    uint32_t base_object_type_id,
371                                    uint32_t index) {
372   auto should_be_composite_type =
373       context->get_def_use_mgr()->GetDef(base_object_type_id);
374   assert(should_be_composite_type && "The type should exist.");
375   switch (should_be_composite_type->opcode()) {
376     case SpvOpTypeArray: {
377       auto array_length = GetArraySize(*should_be_composite_type, context);
378       if (array_length == 0 || index >= array_length) {
379         return 0;
380       }
381       return should_be_composite_type->GetSingleWordInOperand(0);
382     }
383     case SpvOpTypeMatrix:
384     case SpvOpTypeVector: {
385       auto count = should_be_composite_type->GetSingleWordInOperand(1);
386       if (index >= count) {
387         return 0;
388       }
389       return should_be_composite_type->GetSingleWordInOperand(0);
390     }
391     case SpvOpTypeStruct: {
392       if (index >= GetNumberOfStructMembers(*should_be_composite_type)) {
393         return 0;
394       }
395       return should_be_composite_type->GetSingleWordInOperand(index);
396     }
397     default:
398       return 0;
399   }
400 }
401 
WalkCompositeTypeIndices(opt::IRContext * context,uint32_t base_object_type_id,const google::protobuf::RepeatedField<google::protobuf::uint32> & indices)402 uint32_t WalkCompositeTypeIndices(
403     opt::IRContext* context, uint32_t base_object_type_id,
404     const google::protobuf::RepeatedField<google::protobuf::uint32>& indices) {
405   uint32_t sub_object_type_id = base_object_type_id;
406   for (auto index : indices) {
407     sub_object_type_id =
408         WalkOneCompositeTypeIndex(context, sub_object_type_id, index);
409     if (!sub_object_type_id) {
410       return 0;
411     }
412   }
413   return sub_object_type_id;
414 }
415 
GetNumberOfStructMembers(const opt::Instruction & struct_type_instruction)416 uint32_t GetNumberOfStructMembers(
417     const opt::Instruction& struct_type_instruction) {
418   assert(struct_type_instruction.opcode() == SpvOpTypeStruct &&
419          "An OpTypeStruct instruction is required here.");
420   return struct_type_instruction.NumInOperands();
421 }
422 
GetArraySize(const opt::Instruction & array_type_instruction,opt::IRContext * context)423 uint32_t GetArraySize(const opt::Instruction& array_type_instruction,
424                       opt::IRContext* context) {
425   auto array_length_constant =
426       context->get_constant_mgr()
427           ->GetConstantFromInst(context->get_def_use_mgr()->GetDef(
428               array_type_instruction.GetSingleWordInOperand(1)))
429           ->AsIntConstant();
430   if (array_length_constant->words().size() != 1) {
431     return 0;
432   }
433   return array_length_constant->GetU32();
434 }
435 
GetBoundForCompositeIndex(const opt::Instruction & composite_type_inst,opt::IRContext * ir_context)436 uint32_t GetBoundForCompositeIndex(const opt::Instruction& composite_type_inst,
437                                    opt::IRContext* ir_context) {
438   switch (composite_type_inst.opcode()) {
439     case SpvOpTypeArray:
440       return fuzzerutil::GetArraySize(composite_type_inst, ir_context);
441     case SpvOpTypeMatrix:
442     case SpvOpTypeVector:
443       return composite_type_inst.GetSingleWordInOperand(1);
444     case SpvOpTypeStruct: {
445       return fuzzerutil::GetNumberOfStructMembers(composite_type_inst);
446     }
447     case SpvOpTypeRuntimeArray:
448       assert(false &&
449              "GetBoundForCompositeIndex should not be invoked with an "
450              "OpTypeRuntimeArray, which does not have a static bound.");
451       return 0;
452     default:
453       assert(false && "Unknown composite type.");
454       return 0;
455   }
456 }
457 
IsValid(const opt::IRContext * context,spv_validator_options validator_options,MessageConsumer consumer)458 bool IsValid(const opt::IRContext* context,
459              spv_validator_options validator_options,
460              MessageConsumer consumer) {
461   std::vector<uint32_t> binary;
462   context->module()->ToBinary(&binary, false);
463   SpirvTools tools(context->grammar().target_env());
464   tools.SetMessageConsumer(std::move(consumer));
465   return tools.Validate(binary.data(), binary.size(), validator_options);
466 }
467 
IsValidAndWellFormed(const opt::IRContext * ir_context,spv_validator_options validator_options,MessageConsumer consumer)468 bool IsValidAndWellFormed(const opt::IRContext* ir_context,
469                           spv_validator_options validator_options,
470                           MessageConsumer consumer) {
471   if (!IsValid(ir_context, validator_options, consumer)) {
472     // Expression to dump |ir_context| to /data/temp/shader.spv:
473     //    DumpShader(ir_context, "/data/temp/shader.spv")
474     consumer(SPV_MSG_INFO, nullptr, {},
475              "Module is invalid (set a breakpoint to inspect).");
476     return false;
477   }
478   // Check that all blocks in the module have appropriate parent functions.
479   for (auto& function : *ir_context->module()) {
480     for (auto& block : function) {
481       if (block.GetParent() == nullptr) {
482         std::stringstream ss;
483         ss << "Block " << block.id() << " has no parent; its parent should be "
484            << function.result_id() << " (set a breakpoint to inspect).";
485         consumer(SPV_MSG_INFO, nullptr, {}, ss.str().c_str());
486         return false;
487       }
488       if (block.GetParent() != &function) {
489         std::stringstream ss;
490         ss << "Block " << block.id() << " should have parent "
491            << function.result_id() << " but instead has parent "
492            << block.GetParent() << " (set a breakpoint to inspect).";
493         consumer(SPV_MSG_INFO, nullptr, {}, ss.str().c_str());
494         return false;
495       }
496     }
497   }
498 
499   // Check that all instructions have distinct unique ids.  We map each unique
500   // id to the first instruction it is observed to be associated with so that
501   // if we encounter a duplicate we have access to the previous instruction -
502   // this is a useful aid to debugging.
503   std::unordered_map<uint32_t, opt::Instruction*> unique_ids;
504   bool found_duplicate = false;
505   ir_context->module()->ForEachInst([&consumer, &found_duplicate,
506                                      &unique_ids](opt::Instruction* inst) {
507     if (unique_ids.count(inst->unique_id()) != 0) {
508       consumer(SPV_MSG_INFO, nullptr, {},
509                "Two instructions have the same unique id (set a breakpoint to "
510                "inspect).");
511       found_duplicate = true;
512     }
513     unique_ids.insert({inst->unique_id(), inst});
514   });
515   return !found_duplicate;
516 }
517 
CloneIRContext(opt::IRContext * context)518 std::unique_ptr<opt::IRContext> CloneIRContext(opt::IRContext* context) {
519   std::vector<uint32_t> binary;
520   context->module()->ToBinary(&binary, false);
521   return BuildModule(context->grammar().target_env(), nullptr, binary.data(),
522                      binary.size());
523 }
524 
IsNonFunctionTypeId(opt::IRContext * ir_context,uint32_t id)525 bool IsNonFunctionTypeId(opt::IRContext* ir_context, uint32_t id) {
526   auto type = ir_context->get_type_mgr()->GetType(id);
527   return type && !type->AsFunction();
528 }
529 
IsMergeOrContinue(opt::IRContext * ir_context,uint32_t block_id)530 bool IsMergeOrContinue(opt::IRContext* ir_context, uint32_t block_id) {
531   bool result = false;
532   ir_context->get_def_use_mgr()->WhileEachUse(
533       block_id,
534       [&result](const opt::Instruction* use_instruction,
535                 uint32_t /*unused*/) -> bool {
536         switch (use_instruction->opcode()) {
537           case SpvOpLoopMerge:
538           case SpvOpSelectionMerge:
539             result = true;
540             return false;
541           default:
542             return true;
543         }
544       });
545   return result;
546 }
547 
GetLoopFromMergeBlock(opt::IRContext * ir_context,uint32_t merge_block_id)548 uint32_t GetLoopFromMergeBlock(opt::IRContext* ir_context,
549                                uint32_t merge_block_id) {
550   uint32_t result = 0;
551   ir_context->get_def_use_mgr()->WhileEachUse(
552       merge_block_id,
553       [ir_context, &result](opt::Instruction* use_instruction,
554                             uint32_t use_index) -> bool {
555         switch (use_instruction->opcode()) {
556           case SpvOpLoopMerge:
557             // The merge block operand is the first operand in OpLoopMerge.
558             if (use_index == 0) {
559               result = ir_context->get_instr_block(use_instruction)->id();
560               return false;
561             }
562             return true;
563           default:
564             return true;
565         }
566       });
567   return result;
568 }
569 
FindFunctionType(opt::IRContext * ir_context,const std::vector<uint32_t> & type_ids)570 uint32_t FindFunctionType(opt::IRContext* ir_context,
571                           const std::vector<uint32_t>& type_ids) {
572   // Look through the existing types for a match.
573   for (auto& type_or_value : ir_context->types_values()) {
574     if (type_or_value.opcode() != SpvOpTypeFunction) {
575       // We are only interested in function types.
576       continue;
577     }
578     if (type_or_value.NumInOperands() != type_ids.size()) {
579       // Not a match: different numbers of arguments.
580       continue;
581     }
582     // Check whether the return type and argument types match.
583     bool input_operands_match = true;
584     for (uint32_t i = 0; i < type_or_value.NumInOperands(); i++) {
585       if (type_ids[i] != type_or_value.GetSingleWordInOperand(i)) {
586         input_operands_match = false;
587         break;
588       }
589     }
590     if (input_operands_match) {
591       // Everything matches.
592       return type_or_value.result_id();
593     }
594   }
595   // No match was found.
596   return 0;
597 }
598 
GetFunctionType(opt::IRContext * context,const opt::Function * function)599 opt::Instruction* GetFunctionType(opt::IRContext* context,
600                                   const opt::Function* function) {
601   uint32_t type_id = function->DefInst().GetSingleWordInOperand(1);
602   return context->get_def_use_mgr()->GetDef(type_id);
603 }
604 
FindFunction(opt::IRContext * ir_context,uint32_t function_id)605 opt::Function* FindFunction(opt::IRContext* ir_context, uint32_t function_id) {
606   for (auto& function : *ir_context->module()) {
607     if (function.result_id() == function_id) {
608       return &function;
609     }
610   }
611   return nullptr;
612 }
613 
FunctionContainsOpKillOrUnreachable(const opt::Function & function)614 bool FunctionContainsOpKillOrUnreachable(const opt::Function& function) {
615   for (auto& block : function) {
616     if (block.terminator()->opcode() == SpvOpKill ||
617         block.terminator()->opcode() == SpvOpUnreachable) {
618       return true;
619     }
620   }
621   return false;
622 }
623 
FunctionIsEntryPoint(opt::IRContext * context,uint32_t function_id)624 bool FunctionIsEntryPoint(opt::IRContext* context, uint32_t function_id) {
625   for (auto& entry_point : context->module()->entry_points()) {
626     if (entry_point.GetSingleWordInOperand(1) == function_id) {
627       return true;
628     }
629   }
630   return false;
631 }
632 
IdIsAvailableAtUse(opt::IRContext * context,opt::Instruction * use_instruction,uint32_t use_input_operand_index,uint32_t id)633 bool IdIsAvailableAtUse(opt::IRContext* context,
634                         opt::Instruction* use_instruction,
635                         uint32_t use_input_operand_index, uint32_t id) {
636   assert(context->get_instr_block(use_instruction) &&
637          "|use_instruction| must be in a basic block");
638 
639   auto defining_instruction = context->get_def_use_mgr()->GetDef(id);
640   auto enclosing_function =
641       context->get_instr_block(use_instruction)->GetParent();
642   // If the id a function parameter, it needs to be associated with the
643   // function containing the use.
644   if (defining_instruction->opcode() == SpvOpFunctionParameter) {
645     return InstructionIsFunctionParameter(defining_instruction,
646                                           enclosing_function);
647   }
648   if (!context->get_instr_block(id)) {
649     // The id must be at global scope.
650     return true;
651   }
652   if (defining_instruction == use_instruction) {
653     // It is not OK for a definition to use itself.
654     return false;
655   }
656   if (!context->IsReachable(*context->get_instr_block(use_instruction)) ||
657       !context->IsReachable(*context->get_instr_block(id))) {
658     // Skip unreachable blocks.
659     return false;
660   }
661   auto dominator_analysis = context->GetDominatorAnalysis(enclosing_function);
662   if (use_instruction->opcode() == SpvOpPhi) {
663     // In the case where the use is an operand to OpPhi, it is actually the
664     // *parent* block associated with the operand that must be dominated by
665     // the synonym.
666     auto parent_block =
667         use_instruction->GetSingleWordInOperand(use_input_operand_index + 1);
668     return dominator_analysis->Dominates(
669         context->get_instr_block(defining_instruction)->id(), parent_block);
670   }
671   return dominator_analysis->Dominates(defining_instruction, use_instruction);
672 }
673 
IdIsAvailableBeforeInstruction(opt::IRContext * context,opt::Instruction * instruction,uint32_t id)674 bool IdIsAvailableBeforeInstruction(opt::IRContext* context,
675                                     opt::Instruction* instruction,
676                                     uint32_t id) {
677   assert(context->get_instr_block(instruction) &&
678          "|instruction| must be in a basic block");
679 
680   auto id_definition = context->get_def_use_mgr()->GetDef(id);
681   auto function_enclosing_instruction =
682       context->get_instr_block(instruction)->GetParent();
683   // If the id a function parameter, it needs to be associated with the
684   // function containing the instruction.
685   if (id_definition->opcode() == SpvOpFunctionParameter) {
686     return InstructionIsFunctionParameter(id_definition,
687                                           function_enclosing_instruction);
688   }
689   if (!context->get_instr_block(id)) {
690     // The id is at global scope.
691     return true;
692   }
693   if (id_definition == instruction) {
694     // The instruction is not available right before its own definition.
695     return false;
696   }
697   const auto* dominator_analysis =
698       context->GetDominatorAnalysis(function_enclosing_instruction);
699   if (context->IsReachable(*context->get_instr_block(instruction)) &&
700       context->IsReachable(*context->get_instr_block(id)) &&
701       dominator_analysis->Dominates(id_definition, instruction)) {
702     // The id's definition dominates the instruction, and both the definition
703     // and the instruction are in reachable blocks, thus the id is available at
704     // the instruction.
705     return true;
706   }
707   if (id_definition->opcode() == SpvOpVariable &&
708       function_enclosing_instruction ==
709           context->get_instr_block(id)->GetParent()) {
710     assert(!context->IsReachable(*context->get_instr_block(instruction)) &&
711            "If the instruction were in a reachable block we should already "
712            "have returned true.");
713     // The id is a variable and it is in the same function as |instruction|.
714     // This is OK despite |instruction| being unreachable.
715     return true;
716   }
717   return false;
718 }
719 
InstructionIsFunctionParameter(opt::Instruction * instruction,opt::Function * function)720 bool InstructionIsFunctionParameter(opt::Instruction* instruction,
721                                     opt::Function* function) {
722   if (instruction->opcode() != SpvOpFunctionParameter) {
723     return false;
724   }
725   bool found_parameter = false;
726   function->ForEachParam(
727       [instruction, &found_parameter](opt::Instruction* param) {
728         if (param == instruction) {
729           found_parameter = true;
730         }
731       });
732   return found_parameter;
733 }
734 
GetTypeId(opt::IRContext * context,uint32_t result_id)735 uint32_t GetTypeId(opt::IRContext* context, uint32_t result_id) {
736   const auto* inst = context->get_def_use_mgr()->GetDef(result_id);
737   assert(inst && "|result_id| is invalid");
738   return inst->type_id();
739 }
740 
GetPointeeTypeIdFromPointerType(opt::Instruction * pointer_type_inst)741 uint32_t GetPointeeTypeIdFromPointerType(opt::Instruction* pointer_type_inst) {
742   assert(pointer_type_inst && pointer_type_inst->opcode() == SpvOpTypePointer &&
743          "Precondition: |pointer_type_inst| must be OpTypePointer.");
744   return pointer_type_inst->GetSingleWordInOperand(1);
745 }
746 
GetPointeeTypeIdFromPointerType(opt::IRContext * context,uint32_t pointer_type_id)747 uint32_t GetPointeeTypeIdFromPointerType(opt::IRContext* context,
748                                          uint32_t pointer_type_id) {
749   return GetPointeeTypeIdFromPointerType(
750       context->get_def_use_mgr()->GetDef(pointer_type_id));
751 }
752 
GetStorageClassFromPointerType(opt::Instruction * pointer_type_inst)753 SpvStorageClass GetStorageClassFromPointerType(
754     opt::Instruction* pointer_type_inst) {
755   assert(pointer_type_inst && pointer_type_inst->opcode() == SpvOpTypePointer &&
756          "Precondition: |pointer_type_inst| must be OpTypePointer.");
757   return static_cast<SpvStorageClass>(
758       pointer_type_inst->GetSingleWordInOperand(0));
759 }
760 
GetStorageClassFromPointerType(opt::IRContext * context,uint32_t pointer_type_id)761 SpvStorageClass GetStorageClassFromPointerType(opt::IRContext* context,
762                                                uint32_t pointer_type_id) {
763   return GetStorageClassFromPointerType(
764       context->get_def_use_mgr()->GetDef(pointer_type_id));
765 }
766 
MaybeGetPointerType(opt::IRContext * context,uint32_t pointee_type_id,SpvStorageClass storage_class)767 uint32_t MaybeGetPointerType(opt::IRContext* context, uint32_t pointee_type_id,
768                              SpvStorageClass storage_class) {
769   for (auto& inst : context->types_values()) {
770     switch (inst.opcode()) {
771       case SpvOpTypePointer:
772         if (inst.GetSingleWordInOperand(0) == storage_class &&
773             inst.GetSingleWordInOperand(1) == pointee_type_id) {
774           return inst.result_id();
775         }
776         break;
777       default:
778         break;
779     }
780   }
781   return 0;
782 }
783 
InOperandIndexFromOperandIndex(const opt::Instruction & inst,uint32_t absolute_index)784 uint32_t InOperandIndexFromOperandIndex(const opt::Instruction& inst,
785                                         uint32_t absolute_index) {
786   // Subtract the number of non-input operands from the index
787   return absolute_index - inst.NumOperands() + inst.NumInOperands();
788 }
789 
IsNullConstantSupported(const opt::analysis::Type & type)790 bool IsNullConstantSupported(const opt::analysis::Type& type) {
791   return type.AsBool() || type.AsInteger() || type.AsFloat() ||
792          type.AsMatrix() || type.AsVector() || type.AsArray() ||
793          type.AsStruct() || type.AsPointer() || type.AsEvent() ||
794          type.AsDeviceEvent() || type.AsReserveId() || type.AsQueue();
795 }
796 
GlobalVariablesMustBeDeclaredInEntryPointInterfaces(const opt::IRContext * ir_context)797 bool GlobalVariablesMustBeDeclaredInEntryPointInterfaces(
798     const opt::IRContext* ir_context) {
799   // TODO(afd): We capture the environments for which this requirement holds.
800   //  The check should be refined on demand for other target environments.
801   switch (ir_context->grammar().target_env()) {
802     case SPV_ENV_UNIVERSAL_1_0:
803     case SPV_ENV_UNIVERSAL_1_1:
804     case SPV_ENV_UNIVERSAL_1_2:
805     case SPV_ENV_UNIVERSAL_1_3:
806     case SPV_ENV_VULKAN_1_0:
807     case SPV_ENV_VULKAN_1_1:
808       return false;
809     default:
810       return true;
811   }
812 }
813 
AddVariableIdToEntryPointInterfaces(opt::IRContext * context,uint32_t id)814 void AddVariableIdToEntryPointInterfaces(opt::IRContext* context, uint32_t id) {
815   if (GlobalVariablesMustBeDeclaredInEntryPointInterfaces(context)) {
816     // Conservatively add this global to the interface of every entry point in
817     // the module.  This means that the global is available for other
818     // transformations to use.
819     //
820     // A downside of this is that the global will be in the interface even if it
821     // ends up never being used.
822     //
823     // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3111) revisit
824     //  this if a more thorough approach to entry point interfaces is taken.
825     for (auto& entry_point : context->module()->entry_points()) {
826       entry_point.AddOperand({SPV_OPERAND_TYPE_ID, {id}});
827     }
828   }
829 }
830 
AddGlobalVariable(opt::IRContext * context,uint32_t result_id,uint32_t type_id,SpvStorageClass storage_class,uint32_t initializer_id)831 opt::Instruction* AddGlobalVariable(opt::IRContext* context, uint32_t result_id,
832                                     uint32_t type_id,
833                                     SpvStorageClass storage_class,
834                                     uint32_t initializer_id) {
835   // Check various preconditions.
836   assert(result_id != 0 && "Result id can't be 0");
837 
838   assert((storage_class == SpvStorageClassPrivate ||
839           storage_class == SpvStorageClassWorkgroup) &&
840          "Variable's storage class must be either Private or Workgroup");
841 
842   auto* type_inst = context->get_def_use_mgr()->GetDef(type_id);
843   (void)type_inst;  // Variable becomes unused in release mode.
844   assert(type_inst && type_inst->opcode() == SpvOpTypePointer &&
845          GetStorageClassFromPointerType(type_inst) == storage_class &&
846          "Variable's type is invalid");
847 
848   if (storage_class == SpvStorageClassWorkgroup) {
849     assert(initializer_id == 0);
850   }
851 
852   if (initializer_id != 0) {
853     const auto* constant_inst =
854         context->get_def_use_mgr()->GetDef(initializer_id);
855     (void)constant_inst;  // Variable becomes unused in release mode.
856     assert(constant_inst && spvOpcodeIsConstant(constant_inst->opcode()) &&
857            GetPointeeTypeIdFromPointerType(type_inst) ==
858                constant_inst->type_id() &&
859            "Initializer is invalid");
860   }
861 
862   opt::Instruction::OperandList operands = {
863       {SPV_OPERAND_TYPE_STORAGE_CLASS, {static_cast<uint32_t>(storage_class)}}};
864 
865   if (initializer_id) {
866     operands.push_back({SPV_OPERAND_TYPE_ID, {initializer_id}});
867   }
868 
869   auto new_instruction = MakeUnique<opt::Instruction>(
870       context, SpvOpVariable, type_id, result_id, std::move(operands));
871   auto result = new_instruction.get();
872   context->module()->AddGlobalValue(std::move(new_instruction));
873 
874   AddVariableIdToEntryPointInterfaces(context, result_id);
875   UpdateModuleIdBound(context, result_id);
876 
877   return result;
878 }
879 
AddLocalVariable(opt::IRContext * context,uint32_t result_id,uint32_t type_id,uint32_t function_id,uint32_t initializer_id)880 opt::Instruction* AddLocalVariable(opt::IRContext* context, uint32_t result_id,
881                                    uint32_t type_id, uint32_t function_id,
882                                    uint32_t initializer_id) {
883   // Check various preconditions.
884   assert(result_id != 0 && "Result id can't be 0");
885 
886   auto* type_inst = context->get_def_use_mgr()->GetDef(type_id);
887   (void)type_inst;  // Variable becomes unused in release mode.
888   assert(type_inst && type_inst->opcode() == SpvOpTypePointer &&
889          GetStorageClassFromPointerType(type_inst) == SpvStorageClassFunction &&
890          "Variable's type is invalid");
891 
892   const auto* constant_inst =
893       context->get_def_use_mgr()->GetDef(initializer_id);
894   (void)constant_inst;  // Variable becomes unused in release mode.
895   assert(constant_inst && spvOpcodeIsConstant(constant_inst->opcode()) &&
896          GetPointeeTypeIdFromPointerType(type_inst) ==
897              constant_inst->type_id() &&
898          "Initializer is invalid");
899 
900   auto* function = FindFunction(context, function_id);
901   assert(function && "Function id is invalid");
902 
903   auto new_instruction = MakeUnique<opt::Instruction>(
904       context, SpvOpVariable, type_id, result_id,
905       opt::Instruction::OperandList{
906           {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
907           {SPV_OPERAND_TYPE_ID, {initializer_id}}});
908   auto result = new_instruction.get();
909   function->begin()->begin()->InsertBefore(std::move(new_instruction));
910 
911   UpdateModuleIdBound(context, result_id);
912 
913   return result;
914 }
915 
// Returns true iff at least one value occurs more than once in |arr|.
bool HasDuplicates(const std::vector<uint32_t>& arr) {
  std::unordered_set<uint32_t> seen;
  for (uint32_t value : arr) {
    // insert() reports false when |value| was already present.
    if (!seen.insert(value).second) {
      return true;
    }
  }
  return false;
}
920 
IsPermutationOfRange(const std::vector<uint32_t> & arr,uint32_t lo,uint32_t hi)921 bool IsPermutationOfRange(const std::vector<uint32_t>& arr, uint32_t lo,
922                           uint32_t hi) {
923   if (arr.empty()) {
924     return lo > hi;
925   }
926 
927   if (HasDuplicates(arr)) {
928     return false;
929   }
930 
931   auto min_max = std::minmax_element(arr.begin(), arr.end());
932   return arr.size() == hi - lo + 1 && *min_max.first == lo &&
933          *min_max.second == hi;
934 }
935 
GetParameters(opt::IRContext * ir_context,uint32_t function_id)936 std::vector<opt::Instruction*> GetParameters(opt::IRContext* ir_context,
937                                              uint32_t function_id) {
938   auto* function = FindFunction(ir_context, function_id);
939   assert(function && "|function_id| is invalid");
940 
941   std::vector<opt::Instruction*> result;
942   function->ForEachParam(
943       [&result](opt::Instruction* inst) { result.push_back(inst); });
944 
945   return result;
946 }
947 
RemoveParameter(opt::IRContext * ir_context,uint32_t parameter_id)948 void RemoveParameter(opt::IRContext* ir_context, uint32_t parameter_id) {
949   auto* function = GetFunctionFromParameterId(ir_context, parameter_id);
950   assert(function && "|parameter_id| is invalid");
951   assert(!FunctionIsEntryPoint(ir_context, function->result_id()) &&
952          "Can't remove parameter from an entry point function");
953 
954   function->RemoveParameter(parameter_id);
955 
956   // We've just removed parameters from the function and cleared their memory.
957   // Make sure analyses have no dangling pointers.
958   ir_context->InvalidateAnalysesExceptFor(
959       opt::IRContext::Analysis::kAnalysisNone);
960 }
961 
GetCallers(opt::IRContext * ir_context,uint32_t function_id)962 std::vector<opt::Instruction*> GetCallers(opt::IRContext* ir_context,
963                                           uint32_t function_id) {
964   assert(FindFunction(ir_context, function_id) &&
965          "|function_id| is not a result id of a function");
966 
967   std::vector<opt::Instruction*> result;
968   ir_context->get_def_use_mgr()->ForEachUser(
969       function_id, [&result, function_id](opt::Instruction* inst) {
970         if (inst->opcode() == SpvOpFunctionCall &&
971             inst->GetSingleWordInOperand(0) == function_id) {
972           result.push_back(inst);
973         }
974       });
975 
976   return result;
977 }
978 
GetFunctionFromParameterId(opt::IRContext * ir_context,uint32_t param_id)979 opt::Function* GetFunctionFromParameterId(opt::IRContext* ir_context,
980                                           uint32_t param_id) {
981   auto* param_inst = ir_context->get_def_use_mgr()->GetDef(param_id);
982   assert(param_inst && "Parameter id is invalid");
983 
984   for (auto& function : *ir_context->module()) {
985     if (InstructionIsFunctionParameter(param_inst, &function)) {
986       return &function;
987     }
988   }
989 
990   return nullptr;
991 }
992 
UpdateFunctionType(opt::IRContext * ir_context,uint32_t function_id,uint32_t new_function_type_result_id,uint32_t return_type_id,const std::vector<uint32_t> & parameter_type_ids)993 uint32_t UpdateFunctionType(opt::IRContext* ir_context, uint32_t function_id,
994                             uint32_t new_function_type_result_id,
995                             uint32_t return_type_id,
996                             const std::vector<uint32_t>& parameter_type_ids) {
997   // Check some initial constraints.
998   assert(ir_context->get_type_mgr()->GetType(return_type_id) &&
999          "Return type is invalid");
1000   for (auto id : parameter_type_ids) {
1001     const auto* type = ir_context->get_type_mgr()->GetType(id);
1002     (void)type;  // Make compilers happy in release mode.
1003     // Parameters can't be OpTypeVoid.
1004     assert(type && !type->AsVoid() && "Parameter has invalid type");
1005   }
1006 
1007   auto* function = FindFunction(ir_context, function_id);
1008   assert(function && "|function_id| is invalid");
1009 
1010   auto* old_function_type = GetFunctionType(ir_context, function);
1011   assert(old_function_type && "Function has invalid type");
1012 
1013   std::vector<uint32_t> operand_ids = {return_type_id};
1014   operand_ids.insert(operand_ids.end(), parameter_type_ids.begin(),
1015                      parameter_type_ids.end());
1016 
1017   // A trivial case - we change nothing.
1018   if (FindFunctionType(ir_context, operand_ids) ==
1019       old_function_type->result_id()) {
1020     return old_function_type->result_id();
1021   }
1022 
1023   if (ir_context->get_def_use_mgr()->NumUsers(old_function_type) == 1 &&
1024       FindFunctionType(ir_context, operand_ids) == 0) {
1025     // We can change |old_function_type| only if it's used once in the module
1026     // and we are certain we won't create a duplicate as a result of the change.
1027 
1028     // Update |old_function_type| in-place.
1029     opt::Instruction::OperandList operands;
1030     for (auto id : operand_ids) {
1031       operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
1032     }
1033 
1034     old_function_type->SetInOperands(std::move(operands));
1035 
1036     // |operands| may depend on result ids defined below the |old_function_type|
1037     // in the module.
1038     old_function_type->RemoveFromList();
1039     ir_context->AddType(std::unique_ptr<opt::Instruction>(old_function_type));
1040     return old_function_type->result_id();
1041   } else {
1042     // We can't modify the |old_function_type| so we have to either use an
1043     // existing one or create a new one.
1044     auto type_id = FindOrCreateFunctionType(
1045         ir_context, new_function_type_result_id, operand_ids);
1046     assert(type_id != old_function_type->result_id() &&
1047            "We should've handled this case above");
1048 
1049     function->DefInst().SetInOperand(1, {type_id});
1050 
1051     // DefUseManager hasn't been updated yet, so if the following condition is
1052     // true, then |old_function_type| will have no users when this function
1053     // returns. We might as well remove it.
1054     if (ir_context->get_def_use_mgr()->NumUsers(old_function_type) == 1) {
1055       ir_context->KillInst(old_function_type);
1056     }
1057 
1058     return type_id;
1059   }
1060 }
1061 
AddFunctionType(opt::IRContext * ir_context,uint32_t result_id,const std::vector<uint32_t> & type_ids)1062 void AddFunctionType(opt::IRContext* ir_context, uint32_t result_id,
1063                      const std::vector<uint32_t>& type_ids) {
1064   assert(result_id != 0 && "Result id can't be 0");
1065   assert(!type_ids.empty() &&
1066          "OpTypeFunction always has at least one operand - function's return "
1067          "type");
1068   assert(IsNonFunctionTypeId(ir_context, type_ids[0]) &&
1069          "Return type must not be a function");
1070 
1071   for (size_t i = 1; i < type_ids.size(); ++i) {
1072     const auto* param_type = ir_context->get_type_mgr()->GetType(type_ids[i]);
1073     (void)param_type;  // Make compiler happy in release mode.
1074     assert(param_type && !param_type->AsVoid() && !param_type->AsFunction() &&
1075            "Function parameter can't have a function or void type");
1076   }
1077 
1078   opt::Instruction::OperandList operands;
1079   operands.reserve(type_ids.size());
1080   for (auto id : type_ids) {
1081     operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
1082   }
1083 
1084   ir_context->AddType(MakeUnique<opt::Instruction>(
1085       ir_context, SpvOpTypeFunction, 0, result_id, std::move(operands)));
1086 
1087   UpdateModuleIdBound(ir_context, result_id);
1088 }
1089 
FindOrCreateFunctionType(opt::IRContext * ir_context,uint32_t result_id,const std::vector<uint32_t> & type_ids)1090 uint32_t FindOrCreateFunctionType(opt::IRContext* ir_context,
1091                                   uint32_t result_id,
1092                                   const std::vector<uint32_t>& type_ids) {
1093   if (auto existing_id = FindFunctionType(ir_context, type_ids)) {
1094     return existing_id;
1095   }
1096   AddFunctionType(ir_context, result_id, type_ids);
1097   return result_id;
1098 }
1099 
MaybeGetIntegerType(opt::IRContext * ir_context,uint32_t width,bool is_signed)1100 uint32_t MaybeGetIntegerType(opt::IRContext* ir_context, uint32_t width,
1101                              bool is_signed) {
1102   opt::analysis::Integer type(width, is_signed);
1103   return ir_context->get_type_mgr()->GetId(&type);
1104 }
1105 
MaybeGetFloatType(opt::IRContext * ir_context,uint32_t width)1106 uint32_t MaybeGetFloatType(opt::IRContext* ir_context, uint32_t width) {
1107   opt::analysis::Float type(width);
1108   return ir_context->get_type_mgr()->GetId(&type);
1109 }
1110 
MaybeGetBoolType(opt::IRContext * ir_context)1111 uint32_t MaybeGetBoolType(opt::IRContext* ir_context) {
1112   opt::analysis::Bool type;
1113   return ir_context->get_type_mgr()->GetId(&type);
1114 }
1115 
MaybeGetVectorType(opt::IRContext * ir_context,uint32_t component_type_id,uint32_t element_count)1116 uint32_t MaybeGetVectorType(opt::IRContext* ir_context,
1117                             uint32_t component_type_id,
1118                             uint32_t element_count) {
1119   const auto* component_type =
1120       ir_context->get_type_mgr()->GetType(component_type_id);
1121   assert(component_type &&
1122          (component_type->AsInteger() || component_type->AsFloat() ||
1123           component_type->AsBool()) &&
1124          "|component_type_id| is invalid");
1125   assert(element_count >= 2 && element_count <= 4 &&
1126          "Precondition: component count must be in range [2, 4].");
1127   opt::analysis::Vector type(component_type, element_count);
1128   return ir_context->get_type_mgr()->GetId(&type);
1129 }
1130 
MaybeGetStructType(opt::IRContext * ir_context,const std::vector<uint32_t> & component_type_ids)1131 uint32_t MaybeGetStructType(opt::IRContext* ir_context,
1132                             const std::vector<uint32_t>& component_type_ids) {
1133   for (auto& type_or_value : ir_context->types_values()) {
1134     if (type_or_value.opcode() != SpvOpTypeStruct ||
1135         type_or_value.NumInOperands() !=
1136             static_cast<uint32_t>(component_type_ids.size())) {
1137       continue;
1138     }
1139     bool all_components_match = true;
1140     for (uint32_t i = 0; i < component_type_ids.size(); i++) {
1141       if (type_or_value.GetSingleWordInOperand(i) != component_type_ids[i]) {
1142         all_components_match = false;
1143         break;
1144       }
1145     }
1146     if (all_components_match) {
1147       return type_or_value.result_id();
1148     }
1149   }
1150   return 0;
1151 }
1152 
MaybeGetVoidType(opt::IRContext * ir_context)1153 uint32_t MaybeGetVoidType(opt::IRContext* ir_context) {
1154   opt::analysis::Void type;
1155   return ir_context->get_type_mgr()->GetId(&type);
1156 }
1157 
MaybeGetZeroConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,uint32_t scalar_or_composite_type_id,bool is_irrelevant)1158 uint32_t MaybeGetZeroConstant(
1159     opt::IRContext* ir_context,
1160     const TransformationContext& transformation_context,
1161     uint32_t scalar_or_composite_type_id, bool is_irrelevant) {
1162   const auto* type_inst =
1163       ir_context->get_def_use_mgr()->GetDef(scalar_or_composite_type_id);
1164   assert(type_inst && "|scalar_or_composite_type_id| is invalid");
1165 
1166   switch (type_inst->opcode()) {
1167     case SpvOpTypeBool:
1168       return MaybeGetBoolConstant(ir_context, transformation_context, false,
1169                                   is_irrelevant);
1170     case SpvOpTypeFloat:
1171     case SpvOpTypeInt: {
1172       const auto width = type_inst->GetSingleWordInOperand(0);
1173       std::vector<uint32_t> words = {0};
1174       if (width > 32) {
1175         words.push_back(0);
1176       }
1177 
1178       return MaybeGetScalarConstant(ir_context, transformation_context, words,
1179                                     scalar_or_composite_type_id, is_irrelevant);
1180     }
1181     case SpvOpTypeStruct: {
1182       std::vector<uint32_t> component_ids;
1183       for (uint32_t i = 0; i < type_inst->NumInOperands(); ++i) {
1184         const auto component_type_id = type_inst->GetSingleWordInOperand(i);
1185 
1186         auto component_id =
1187             MaybeGetZeroConstant(ir_context, transformation_context,
1188                                  component_type_id, is_irrelevant);
1189 
1190         if (component_id == 0 && is_irrelevant) {
1191           // Irrelevant constants can use either relevant or irrelevant
1192           // constituents.
1193           component_id = MaybeGetZeroConstant(
1194               ir_context, transformation_context, component_type_id, false);
1195         }
1196 
1197         if (component_id == 0) {
1198           return 0;
1199         }
1200 
1201         component_ids.push_back(component_id);
1202       }
1203 
1204       return MaybeGetCompositeConstant(
1205           ir_context, transformation_context, component_ids,
1206           scalar_or_composite_type_id, is_irrelevant);
1207     }
1208     case SpvOpTypeMatrix:
1209     case SpvOpTypeVector: {
1210       const auto component_type_id = type_inst->GetSingleWordInOperand(0);
1211 
1212       auto component_id = MaybeGetZeroConstant(
1213           ir_context, transformation_context, component_type_id, is_irrelevant);
1214 
1215       if (component_id == 0 && is_irrelevant) {
1216         // Irrelevant constants can use either relevant or irrelevant
1217         // constituents.
1218         component_id = MaybeGetZeroConstant(ir_context, transformation_context,
1219                                             component_type_id, false);
1220       }
1221 
1222       if (component_id == 0) {
1223         return 0;
1224       }
1225 
1226       const auto component_count = type_inst->GetSingleWordInOperand(1);
1227       return MaybeGetCompositeConstant(
1228           ir_context, transformation_context,
1229           std::vector<uint32_t>(component_count, component_id),
1230           scalar_or_composite_type_id, is_irrelevant);
1231     }
1232     case SpvOpTypeArray: {
1233       const auto component_type_id = type_inst->GetSingleWordInOperand(0);
1234 
1235       auto component_id = MaybeGetZeroConstant(
1236           ir_context, transformation_context, component_type_id, is_irrelevant);
1237 
1238       if (component_id == 0 && is_irrelevant) {
1239         // Irrelevant constants can use either relevant or irrelevant
1240         // constituents.
1241         component_id = MaybeGetZeroConstant(ir_context, transformation_context,
1242                                             component_type_id, false);
1243       }
1244 
1245       if (component_id == 0) {
1246         return 0;
1247       }
1248 
1249       return MaybeGetCompositeConstant(
1250           ir_context, transformation_context,
1251           std::vector<uint32_t>(GetArraySize(*type_inst, ir_context),
1252                                 component_id),
1253           scalar_or_composite_type_id, is_irrelevant);
1254     }
1255     default:
1256       assert(false && "Type is not supported");
1257       return 0;
1258   }
1259 }
1260 
CanCreateConstant(opt::IRContext * ir_context,uint32_t type_id)1261 bool CanCreateConstant(opt::IRContext* ir_context, uint32_t type_id) {
1262   opt::Instruction* type_instr = ir_context->get_def_use_mgr()->GetDef(type_id);
1263   assert(type_instr != nullptr && "The type must exist.");
1264   assert(spvOpcodeGeneratesType(type_instr->opcode()) &&
1265          "A type-generating opcode was expected.");
1266   switch (type_instr->opcode()) {
1267     case SpvOpTypeBool:
1268     case SpvOpTypeInt:
1269     case SpvOpTypeFloat:
1270     case SpvOpTypeMatrix:
1271     case SpvOpTypeVector:
1272       return true;
1273     case SpvOpTypeArray:
1274       return CanCreateConstant(ir_context,
1275                                type_instr->GetSingleWordInOperand(0));
1276     case SpvOpTypeStruct:
1277       if (HasBlockOrBufferBlockDecoration(ir_context, type_id)) {
1278         return false;
1279       }
1280       for (uint32_t index = 0; index < type_instr->NumInOperands(); index++) {
1281         if (!CanCreateConstant(ir_context,
1282                                type_instr->GetSingleWordInOperand(index))) {
1283           return false;
1284         }
1285       }
1286       return true;
1287     default:
1288       return false;
1289   }
1290 }
1291 
MaybeGetScalarConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t scalar_type_id,bool is_irrelevant)1292 uint32_t MaybeGetScalarConstant(
1293     opt::IRContext* ir_context,
1294     const TransformationContext& transformation_context,
1295     const std::vector<uint32_t>& words, uint32_t scalar_type_id,
1296     bool is_irrelevant) {
1297   const auto* type = ir_context->get_type_mgr()->GetType(scalar_type_id);
1298   assert(type && "|scalar_type_id| is invalid");
1299 
1300   if (const auto* int_type = type->AsInteger()) {
1301     return MaybeGetIntegerConstant(ir_context, transformation_context, words,
1302                                    int_type->width(), int_type->IsSigned(),
1303                                    is_irrelevant);
1304   } else if (const auto* float_type = type->AsFloat()) {
1305     return MaybeGetFloatConstant(ir_context, transformation_context, words,
1306                                  float_type->width(), is_irrelevant);
1307   } else {
1308     assert(type->AsBool() && words.size() == 1 &&
1309            "|scalar_type_id| doesn't represent a scalar type");
1310     return MaybeGetBoolConstant(ir_context, transformation_context, words[0],
1311                                 is_irrelevant);
1312   }
1313 }
1314 
MaybeGetCompositeConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & component_ids,uint32_t composite_type_id,bool is_irrelevant)1315 uint32_t MaybeGetCompositeConstant(
1316     opt::IRContext* ir_context,
1317     const TransformationContext& transformation_context,
1318     const std::vector<uint32_t>& component_ids, uint32_t composite_type_id,
1319     bool is_irrelevant) {
1320   const auto* type = ir_context->get_type_mgr()->GetType(composite_type_id);
1321   (void)type;  // Make compilers happy in release mode.
1322   assert(IsCompositeType(type) && "|composite_type_id| is invalid");
1323 
1324   for (const auto& inst : ir_context->types_values()) {
1325     if (inst.opcode() == SpvOpConstantComposite &&
1326         inst.type_id() == composite_type_id &&
1327         transformation_context.GetFactManager()->IdIsIrrelevant(
1328             inst.result_id()) == is_irrelevant &&
1329         inst.NumInOperands() == component_ids.size()) {
1330       bool is_match = true;
1331 
1332       for (uint32_t i = 0; i < inst.NumInOperands(); ++i) {
1333         if (inst.GetSingleWordInOperand(i) != component_ids[i]) {
1334           is_match = false;
1335           break;
1336         }
1337       }
1338 
1339       if (is_match) {
1340         return inst.result_id();
1341       }
1342     }
1343   }
1344 
1345   return 0;
1346 }
1347 
MaybeGetIntegerConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t width,bool is_signed,bool is_irrelevant)1348 uint32_t MaybeGetIntegerConstant(
1349     opt::IRContext* ir_context,
1350     const TransformationContext& transformation_context,
1351     const std::vector<uint32_t>& words, uint32_t width, bool is_signed,
1352     bool is_irrelevant) {
1353   if (auto type_id = MaybeGetIntegerType(ir_context, width, is_signed)) {
1354     return MaybeGetOpConstant(ir_context, transformation_context, words,
1355                               type_id, is_irrelevant);
1356   }
1357 
1358   return 0;
1359 }
1360 
MaybeGetIntegerConstantFromValueAndType(opt::IRContext * ir_context,uint32_t value,uint32_t int_type_id)1361 uint32_t MaybeGetIntegerConstantFromValueAndType(opt::IRContext* ir_context,
1362                                                  uint32_t value,
1363                                                  uint32_t int_type_id) {
1364   auto int_type_inst = ir_context->get_def_use_mgr()->GetDef(int_type_id);
1365 
1366   assert(int_type_inst && "The given type id must exist.");
1367 
1368   auto int_type = ir_context->get_type_mgr()
1369                       ->GetType(int_type_inst->result_id())
1370                       ->AsInteger();
1371 
1372   assert(int_type && int_type->width() == 32 &&
1373          "The given type id must correspond to an 32-bit integer type.");
1374 
1375   opt::analysis::IntConstant constant(int_type, {value});
1376 
1377   // Check that the constant exists in the module.
1378   if (!ir_context->get_constant_mgr()->FindConstant(&constant)) {
1379     return 0;
1380   }
1381 
1382   return ir_context->get_constant_mgr()
1383       ->GetDefiningInstruction(&constant)
1384       ->result_id();
1385 }
1386 
MaybeGetFloatConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t width,bool is_irrelevant)1387 uint32_t MaybeGetFloatConstant(
1388     opt::IRContext* ir_context,
1389     const TransformationContext& transformation_context,
1390     const std::vector<uint32_t>& words, uint32_t width, bool is_irrelevant) {
1391   if (auto type_id = MaybeGetFloatType(ir_context, width)) {
1392     return MaybeGetOpConstant(ir_context, transformation_context, words,
1393                               type_id, is_irrelevant);
1394   }
1395 
1396   return 0;
1397 }
1398 
MaybeGetBoolConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,bool value,bool is_irrelevant)1399 uint32_t MaybeGetBoolConstant(
1400     opt::IRContext* ir_context,
1401     const TransformationContext& transformation_context, bool value,
1402     bool is_irrelevant) {
1403   if (auto type_id = MaybeGetBoolType(ir_context)) {
1404     for (const auto& inst : ir_context->types_values()) {
1405       if (inst.opcode() == (value ? SpvOpConstantTrue : SpvOpConstantFalse) &&
1406           inst.type_id() == type_id &&
1407           transformation_context.GetFactManager()->IdIsIrrelevant(
1408               inst.result_id()) == is_irrelevant) {
1409         return inst.result_id();
1410       }
1411     }
1412   }
1413 
1414   return 0;
1415 }
1416 
// Encodes the low |width| bits of |value| as SPIR-V literal words.
//
// The low |width| bits are sign-extended to 64 bits if |is_signed| and
// zero-extended otherwise. The result holds one word when |width| <= 32 and
// two words (low-order word first) otherwise.
//
// |width| must be in [1, 64]. Extension is done with bit masks rather than a
// shift-left/shift-right pair on a signed integer. This keeps every shift
// amount below 64 (a shift by 64, reached with |width| == 0, is undefined
// behaviour). It also avoids converting an out-of-range uint64_t to int64_t
// and right-shifting a negative value, both implementation-defined before
// C++20.
std::vector<uint32_t> IntToWords(uint64_t value, uint32_t width,
                                 bool is_signed) {
  assert(width > 0 && width <= 64 &&
         "The bit width should be between 1 and 64 bits");

  if (width == 0) {
    // Defensive (release builds only): there are no bits to keep.
    value = 0;
  } else if (width < 64) {
    const uint64_t mask = (uint64_t{1} << width) - 1;
    value &= mask;
    const bool sign_bit_set = ((value >> (width - 1)) & 1) != 0;
    if (is_signed && sign_bit_set) {
      // Sign-extend: set every bit above the |width| low-order bits.
      value |= ~mask;
    }
  }
  // When |width| == 64 every bit of |value| is significant; nothing to do.

  std::vector<uint32_t> result;
  result.push_back(static_cast<uint32_t>(value));
  if (width > 32) {
    result.push_back(static_cast<uint32_t>(value >> 32));
  }
  return result;
}
1440 
TypesAreEqualUpToSign(opt::IRContext * ir_context,uint32_t type1_id,uint32_t type2_id)1441 bool TypesAreEqualUpToSign(opt::IRContext* ir_context, uint32_t type1_id,
1442                            uint32_t type2_id) {
1443   if (type1_id == type2_id) {
1444     return true;
1445   }
1446 
1447   auto type1 = ir_context->get_type_mgr()->GetType(type1_id);
1448   auto type2 = ir_context->get_type_mgr()->GetType(type2_id);
1449 
1450   // Integer scalar types must have the same width
1451   if (type1->AsInteger() && type2->AsInteger()) {
1452     return type1->AsInteger()->width() == type2->AsInteger()->width();
1453   }
1454 
1455   // Integer vector types must have the same number of components and their
1456   // component types must be integers with the same width.
1457   if (type1->AsVector() && type2->AsVector()) {
1458     auto component_type1 = type1->AsVector()->element_type()->AsInteger();
1459     auto component_type2 = type2->AsVector()->element_type()->AsInteger();
1460 
1461     // Only check the component count and width if they are integer.
1462     if (component_type1 && component_type2) {
1463       return type1->AsVector()->element_count() ==
1464                  type2->AsVector()->element_count() &&
1465              component_type1->width() == component_type2->width();
1466     }
1467   }
1468 
1469   // In all other cases, the types cannot be considered equal.
1470   return false;
1471 }
1472 
RepeatedUInt32PairToMap(const google::protobuf::RepeatedPtrField<protobufs::UInt32Pair> & data)1473 std::map<uint32_t, uint32_t> RepeatedUInt32PairToMap(
1474     const google::protobuf::RepeatedPtrField<protobufs::UInt32Pair>& data) {
1475   std::map<uint32_t, uint32_t> result;
1476 
1477   for (const auto& entry : data) {
1478     result[entry.first()] = entry.second();
1479   }
1480 
1481   return result;
1482 }
1483 
1484 google::protobuf::RepeatedPtrField<protobufs::UInt32Pair>
MapToRepeatedUInt32Pair(const std::map<uint32_t,uint32_t> & data)1485 MapToRepeatedUInt32Pair(const std::map<uint32_t, uint32_t>& data) {
1486   google::protobuf::RepeatedPtrField<protobufs::UInt32Pair> result;
1487 
1488   for (const auto& entry : data) {
1489     protobufs::UInt32Pair pair;
1490     pair.set_first(entry.first);
1491     pair.set_second(entry.second);
1492     *result.Add() = std::move(pair);
1493   }
1494 
1495   return result;
1496 }
1497 
GetLastInsertBeforeInstruction(opt::IRContext * ir_context,uint32_t block_id,SpvOp opcode)1498 opt::Instruction* GetLastInsertBeforeInstruction(opt::IRContext* ir_context,
1499                                                  uint32_t block_id,
1500                                                  SpvOp opcode) {
1501   // CFG::block uses std::map::at which throws an exception when |block_id| is
1502   // invalid. The error message is unhelpful, though. Thus, we test that
1503   // |block_id| is valid here.
1504   const auto* label_inst = ir_context->get_def_use_mgr()->GetDef(block_id);
1505   (void)label_inst;  // Make compilers happy in release mode.
1506   assert(label_inst && label_inst->opcode() == SpvOpLabel &&
1507          "|block_id| is invalid");
1508 
1509   auto* block = ir_context->cfg()->block(block_id);
1510   auto it = block->rbegin();
1511   assert(it != block->rend() && "Basic block can't be empty");
1512 
1513   if (block->GetMergeInst()) {
1514     ++it;
1515     assert(it != block->rend() &&
1516            "|block| must have at least two instructions:"
1517            "terminator and a merge instruction");
1518   }
1519 
1520   return CanInsertOpcodeBeforeInstruction(opcode, &*it) ? &*it : nullptr;
1521 }
1522 
IdUseCanBeReplaced(opt::IRContext * ir_context,const TransformationContext & transformation_context,opt::Instruction * use_instruction,uint32_t use_in_operand_index)1523 bool IdUseCanBeReplaced(opt::IRContext* ir_context,
1524                         const TransformationContext& transformation_context,
1525                         opt::Instruction* use_instruction,
1526                         uint32_t use_in_operand_index) {
1527   if (spvOpcodeIsAccessChain(use_instruction->opcode()) &&
1528       use_in_operand_index > 0) {
1529     // A replacement for an irrelevant index in OpAccessChain must be clamped
1530     // first.
1531     if (transformation_context.GetFactManager()->IdIsIrrelevant(
1532             use_instruction->GetSingleWordInOperand(use_in_operand_index))) {
1533       return false;
1534     }
1535 
1536     // This is an access chain index.  If the (sub-)object being accessed by the
1537     // given index has struct type then we cannot replace the use, as it needs
1538     // to be an OpConstant.
1539 
1540     // Get the top-level composite type that is being accessed.
1541     auto object_being_accessed = ir_context->get_def_use_mgr()->GetDef(
1542         use_instruction->GetSingleWordInOperand(0));
1543     auto pointer_type =
1544         ir_context->get_type_mgr()->GetType(object_being_accessed->type_id());
1545     assert(pointer_type->AsPointer());
1546     auto composite_type_being_accessed =
1547         pointer_type->AsPointer()->pointee_type();
1548 
1549     // Now walk the access chain, tracking the type of each sub-object of the
1550     // composite that is traversed, until the index of interest is reached.
1551     for (uint32_t index_in_operand = 1; index_in_operand < use_in_operand_index;
1552          index_in_operand++) {
1553       // For vectors, matrices and arrays, getting the type of the sub-object is
1554       // trivial. For the struct case, the sub-object type is field-sensitive,
1555       // and depends on the constant index that is used.
1556       if (composite_type_being_accessed->AsVector()) {
1557         composite_type_being_accessed =
1558             composite_type_being_accessed->AsVector()->element_type();
1559       } else if (composite_type_being_accessed->AsMatrix()) {
1560         composite_type_being_accessed =
1561             composite_type_being_accessed->AsMatrix()->element_type();
1562       } else if (composite_type_being_accessed->AsArray()) {
1563         composite_type_being_accessed =
1564             composite_type_being_accessed->AsArray()->element_type();
1565       } else if (composite_type_being_accessed->AsRuntimeArray()) {
1566         composite_type_being_accessed =
1567             composite_type_being_accessed->AsRuntimeArray()->element_type();
1568       } else {
1569         assert(composite_type_being_accessed->AsStruct());
1570         auto constant_index_instruction = ir_context->get_def_use_mgr()->GetDef(
1571             use_instruction->GetSingleWordInOperand(index_in_operand));
1572         assert(constant_index_instruction->opcode() == SpvOpConstant);
1573         uint32_t member_index =
1574             constant_index_instruction->GetSingleWordInOperand(0);
1575         composite_type_being_accessed =
1576             composite_type_being_accessed->AsStruct()
1577                 ->element_types()[member_index];
1578       }
1579     }
1580 
1581     // We have found the composite type being accessed by the index we are
1582     // considering replacing. If it is a struct, then we cannot do the
1583     // replacement as struct indices must be constants.
1584     if (composite_type_being_accessed->AsStruct()) {
1585       return false;
1586     }
1587   }
1588 
1589   if (use_instruction->opcode() == SpvOpFunctionCall &&
1590       use_in_operand_index > 0) {
1591     // This is a function call argument.  It is not allowed to have pointer
1592     // type.
1593 
1594     // Get the definition of the function being called.
1595     auto function = ir_context->get_def_use_mgr()->GetDef(
1596         use_instruction->GetSingleWordInOperand(0));
1597     // From the function definition, get the function type.
1598     auto function_type = ir_context->get_def_use_mgr()->GetDef(
1599         function->GetSingleWordInOperand(1));
1600     // OpTypeFunction's 0-th input operand is the function return type, and the
1601     // function argument types follow. Because the arguments to OpFunctionCall
1602     // start from input operand 1, we can use |use_in_operand_index| to get the
1603     // type associated with this function argument.
1604     auto parameter_type = ir_context->get_type_mgr()->GetType(
1605         function_type->GetSingleWordInOperand(use_in_operand_index));
1606     if (parameter_type->AsPointer()) {
1607       return false;
1608     }
1609   }
1610 
1611   if (use_instruction->opcode() == SpvOpImageTexelPointer &&
1612       use_in_operand_index == 2) {
1613     // The OpImageTexelPointer instruction has a Sample parameter that in some
1614     // situations must be an id for the value 0.  To guard against disrupting
1615     // that requirement, we do not replace this argument to that instruction.
1616     return false;
1617   }
1618 
1619   if (ir_context->get_feature_mgr()->HasCapability(SpvCapabilityShader)) {
1620     // With the Shader capability, memory scope and memory semantics operands
1621     // are required to be constants, so they cannot be replaced arbitrarily.
1622     switch (use_instruction->opcode()) {
1623       case SpvOpAtomicLoad:
1624       case SpvOpAtomicStore:
1625       case SpvOpAtomicExchange:
1626       case SpvOpAtomicIIncrement:
1627       case SpvOpAtomicIDecrement:
1628       case SpvOpAtomicIAdd:
1629       case SpvOpAtomicISub:
1630       case SpvOpAtomicSMin:
1631       case SpvOpAtomicUMin:
1632       case SpvOpAtomicSMax:
1633       case SpvOpAtomicUMax:
1634       case SpvOpAtomicAnd:
1635       case SpvOpAtomicOr:
1636       case SpvOpAtomicXor:
1637         if (use_in_operand_index == 1 || use_in_operand_index == 2) {
1638           return false;
1639         }
1640         break;
1641       case SpvOpAtomicCompareExchange:
1642         if (use_in_operand_index == 1 || use_in_operand_index == 2 ||
1643             use_in_operand_index == 3) {
1644           return false;
1645         }
1646         break;
1647       case SpvOpAtomicCompareExchangeWeak:
1648       case SpvOpAtomicFlagTestAndSet:
1649       case SpvOpAtomicFlagClear:
1650       case SpvOpAtomicFAddEXT:
1651         assert(false && "Not allowed with the Shader capability.");
1652       default:
1653         break;
1654     }
1655   }
1656 
1657   return true;
1658 }
1659 
MembersHaveBuiltInDecoration(opt::IRContext * ir_context,uint32_t struct_type_id)1660 bool MembersHaveBuiltInDecoration(opt::IRContext* ir_context,
1661                                   uint32_t struct_type_id) {
1662   const auto* type_inst = ir_context->get_def_use_mgr()->GetDef(struct_type_id);
1663   assert(type_inst && type_inst->opcode() == SpvOpTypeStruct &&
1664          "|struct_type_id| is not a result id of an OpTypeStruct");
1665 
1666   uint32_t builtin_count = 0;
1667   ir_context->get_def_use_mgr()->ForEachUser(
1668       type_inst,
1669       [struct_type_id, &builtin_count](const opt::Instruction* user) {
1670         if (user->opcode() == SpvOpMemberDecorate &&
1671             user->GetSingleWordInOperand(0) == struct_type_id &&
1672             static_cast<SpvDecoration>(user->GetSingleWordInOperand(2)) ==
1673                 SpvDecorationBuiltIn) {
1674           ++builtin_count;
1675         }
1676       });
1677 
1678   assert((builtin_count == 0 || builtin_count == type_inst->NumInOperands()) &&
1679          "The module is invalid: either none or all of the members of "
1680          "|struct_type_id| may be builtin");
1681 
1682   return builtin_count != 0;
1683 }
1684 
HasBlockOrBufferBlockDecoration(opt::IRContext * ir_context,uint32_t id)1685 bool HasBlockOrBufferBlockDecoration(opt::IRContext* ir_context, uint32_t id) {
1686   for (auto decoration : {SpvDecorationBlock, SpvDecorationBufferBlock}) {
1687     if (!ir_context->get_decoration_mgr()->WhileEachDecoration(
1688             id, decoration, [](const opt::Instruction & /*unused*/) -> bool {
1689               return false;
1690             })) {
1691       return true;
1692     }
1693   }
1694   return false;
1695 }
1696 
SplittingBeforeInstructionSeparatesOpSampledImageDefinitionFromUse(opt::BasicBlock * block_to_split,opt::Instruction * split_before)1697 bool SplittingBeforeInstructionSeparatesOpSampledImageDefinitionFromUse(
1698     opt::BasicBlock* block_to_split, opt::Instruction* split_before) {
1699   std::set<uint32_t> sampled_image_result_ids;
1700   bool before_split = true;
1701 
1702   // Check all the instructions in the block to split.
1703   for (auto& instruction : *block_to_split) {
1704     if (&instruction == &*split_before) {
1705       before_split = false;
1706     }
1707     if (before_split) {
1708       // If the instruction comes before the split and its opcode is
1709       // OpSampledImage, record its result id.
1710       if (instruction.opcode() == SpvOpSampledImage) {
1711         sampled_image_result_ids.insert(instruction.result_id());
1712       }
1713     } else {
1714       // If the instruction comes after the split, check if ids
1715       // corresponding to OpSampledImage instructions defined before the split
1716       // are used, and return true if they are.
1717       if (!instruction.WhileEachInId(
1718               [&sampled_image_result_ids](uint32_t* id) -> bool {
1719                 return !sampled_image_result_ids.count(*id);
1720               })) {
1721         return true;
1722       }
1723     }
1724   }
1725 
1726   // No usage that would be separated from the definition has been found.
1727   return false;
1728 }
1729 
InstructionHasNoSideEffects(const opt::Instruction & instruction)1730 bool InstructionHasNoSideEffects(const opt::Instruction& instruction) {
1731   switch (instruction.opcode()) {
1732     case SpvOpUndef:
1733     case SpvOpAccessChain:
1734     case SpvOpInBoundsAccessChain:
1735     case SpvOpArrayLength:
1736     case SpvOpVectorExtractDynamic:
1737     case SpvOpVectorInsertDynamic:
1738     case SpvOpVectorShuffle:
1739     case SpvOpCompositeConstruct:
1740     case SpvOpCompositeExtract:
1741     case SpvOpCompositeInsert:
1742     case SpvOpCopyObject:
1743     case SpvOpTranspose:
1744     case SpvOpConvertFToU:
1745     case SpvOpConvertFToS:
1746     case SpvOpConvertSToF:
1747     case SpvOpConvertUToF:
1748     case SpvOpUConvert:
1749     case SpvOpSConvert:
1750     case SpvOpFConvert:
1751     case SpvOpQuantizeToF16:
1752     case SpvOpSatConvertSToU:
1753     case SpvOpSatConvertUToS:
1754     case SpvOpBitcast:
1755     case SpvOpSNegate:
1756     case SpvOpFNegate:
1757     case SpvOpIAdd:
1758     case SpvOpFAdd:
1759     case SpvOpISub:
1760     case SpvOpFSub:
1761     case SpvOpIMul:
1762     case SpvOpFMul:
1763     case SpvOpUDiv:
1764     case SpvOpSDiv:
1765     case SpvOpFDiv:
1766     case SpvOpUMod:
1767     case SpvOpSRem:
1768     case SpvOpSMod:
1769     case SpvOpFRem:
1770     case SpvOpFMod:
1771     case SpvOpVectorTimesScalar:
1772     case SpvOpMatrixTimesScalar:
1773     case SpvOpVectorTimesMatrix:
1774     case SpvOpMatrixTimesVector:
1775     case SpvOpMatrixTimesMatrix:
1776     case SpvOpOuterProduct:
1777     case SpvOpDot:
1778     case SpvOpIAddCarry:
1779     case SpvOpISubBorrow:
1780     case SpvOpUMulExtended:
1781     case SpvOpSMulExtended:
1782     case SpvOpAny:
1783     case SpvOpAll:
1784     case SpvOpIsNan:
1785     case SpvOpIsInf:
1786     case SpvOpIsFinite:
1787     case SpvOpIsNormal:
1788     case SpvOpSignBitSet:
1789     case SpvOpLessOrGreater:
1790     case SpvOpOrdered:
1791     case SpvOpUnordered:
1792     case SpvOpLogicalEqual:
1793     case SpvOpLogicalNotEqual:
1794     case SpvOpLogicalOr:
1795     case SpvOpLogicalAnd:
1796     case SpvOpLogicalNot:
1797     case SpvOpSelect:
1798     case SpvOpIEqual:
1799     case SpvOpINotEqual:
1800     case SpvOpUGreaterThan:
1801     case SpvOpSGreaterThan:
1802     case SpvOpUGreaterThanEqual:
1803     case SpvOpSGreaterThanEqual:
1804     case SpvOpULessThan:
1805     case SpvOpSLessThan:
1806     case SpvOpULessThanEqual:
1807     case SpvOpSLessThanEqual:
1808     case SpvOpFOrdEqual:
1809     case SpvOpFUnordEqual:
1810     case SpvOpFOrdNotEqual:
1811     case SpvOpFUnordNotEqual:
1812     case SpvOpFOrdLessThan:
1813     case SpvOpFUnordLessThan:
1814     case SpvOpFOrdGreaterThan:
1815     case SpvOpFUnordGreaterThan:
1816     case SpvOpFOrdLessThanEqual:
1817     case SpvOpFUnordLessThanEqual:
1818     case SpvOpFOrdGreaterThanEqual:
1819     case SpvOpFUnordGreaterThanEqual:
1820     case SpvOpShiftRightLogical:
1821     case SpvOpShiftRightArithmetic:
1822     case SpvOpShiftLeftLogical:
1823     case SpvOpBitwiseOr:
1824     case SpvOpBitwiseXor:
1825     case SpvOpBitwiseAnd:
1826     case SpvOpNot:
1827     case SpvOpBitFieldInsert:
1828     case SpvOpBitFieldSExtract:
1829     case SpvOpBitFieldUExtract:
1830     case SpvOpBitReverse:
1831     case SpvOpBitCount:
1832     case SpvOpCopyLogical:
1833     case SpvOpPhi:
1834     case SpvOpPtrEqual:
1835     case SpvOpPtrNotEqual:
1836       return true;
1837     default:
1838       return false;
1839   }
1840 }
1841 
GetReachableReturnBlocks(opt::IRContext * ir_context,uint32_t function_id)1842 std::set<uint32_t> GetReachableReturnBlocks(opt::IRContext* ir_context,
1843                                             uint32_t function_id) {
1844   auto function = ir_context->GetFunction(function_id);
1845   assert(function && "The function |function_id| must exist.");
1846 
1847   std::set<uint32_t> result;
1848 
1849   ir_context->cfg()->ForEachBlockInPostOrder(function->entry().get(),
1850                                              [&result](opt::BasicBlock* block) {
1851                                                if (block->IsReturn()) {
1852                                                  result.emplace(block->id());
1853                                                }
1854                                              });
1855 
1856   return result;
1857 }
1858 
NewTerminatorPreservesDominationRules(opt::IRContext * ir_context,uint32_t block_id,opt::Instruction new_terminator)1859 bool NewTerminatorPreservesDominationRules(opt::IRContext* ir_context,
1860                                            uint32_t block_id,
1861                                            opt::Instruction new_terminator) {
1862   auto* mutated_block = MaybeFindBlock(ir_context, block_id);
1863   assert(mutated_block && "|block_id| is invalid");
1864 
1865   ChangeTerminatorRAII change_terminator_raii(mutated_block,
1866                                               std::move(new_terminator));
1867   opt::DominatorAnalysis dominator_analysis;
1868   dominator_analysis.InitializeTree(*ir_context->cfg(),
1869                                     mutated_block->GetParent());
1870 
1871   // Check that each dominator appears before each dominated block.
1872   std::unordered_map<uint32_t, size_t> positions;
1873   for (const auto& block : *mutated_block->GetParent()) {
1874     positions[block.id()] = positions.size();
1875   }
1876 
1877   std::queue<uint32_t> q({mutated_block->GetParent()->begin()->id()});
1878   std::unordered_set<uint32_t> visited;
1879   while (!q.empty()) {
1880     auto block = q.front();
1881     q.pop();
1882     visited.insert(block);
1883 
1884     auto success = ir_context->cfg()->block(block)->WhileEachSuccessorLabel(
1885         [&positions, &visited, &dominator_analysis, block, &q](uint32_t id) {
1886           if (id == block) {
1887             // Handle the case when loop header and continue target are the same
1888             // block.
1889             return true;
1890           }
1891 
1892           if (dominator_analysis.Dominates(block, id) &&
1893               positions[block] > positions[id]) {
1894             // |block| dominates |id| but appears after |id| - violates
1895             // domination rules.
1896             return false;
1897           }
1898 
1899           if (!visited.count(id)) {
1900             q.push(id);
1901           }
1902 
1903           return true;
1904         });
1905 
1906     if (!success) {
1907       return false;
1908     }
1909   }
1910 
1911   // For each instruction in the |block->GetParent()| function check whether
1912   // all its dependencies satisfy domination rules (i.e. all id operands
1913   // dominate that instruction).
1914   for (const auto& block : *mutated_block->GetParent()) {
1915     if (!ir_context->IsReachable(block)) {
1916       // If some block is not reachable then we don't need to worry about the
1917       // preservation of domination rules for its instructions.
1918       continue;
1919     }
1920 
1921     for (const auto& inst : block) {
1922       for (uint32_t i = 0; i < inst.NumInOperands();
1923            i += inst.opcode() == SpvOpPhi ? 2 : 1) {
1924         const auto& operand = inst.GetInOperand(i);
1925         if (!spvIsInIdType(operand.type)) {
1926           continue;
1927         }
1928 
1929         if (MaybeFindBlock(ir_context, operand.words[0])) {
1930           // Ignore operands that refer to OpLabel instructions.
1931           continue;
1932         }
1933 
1934         const auto* dependency_block =
1935             ir_context->get_instr_block(operand.words[0]);
1936         if (!dependency_block) {
1937           // A global instruction always dominates all instructions in any
1938           // function.
1939           continue;
1940         }
1941 
1942         auto domination_target_id = inst.opcode() == SpvOpPhi
1943                                         ? inst.GetSingleWordInOperand(i + 1)
1944                                         : block.id();
1945 
1946         if (!dominator_analysis.Dominates(dependency_block->id(),
1947                                           domination_target_id)) {
1948           return false;
1949         }
1950       }
1951     }
1952   }
1953 
1954   return true;
1955 }
1956 
GetFunctionIterator(opt::IRContext * ir_context,uint32_t function_id)1957 opt::Module::iterator GetFunctionIterator(opt::IRContext* ir_context,
1958                                           uint32_t function_id) {
1959   return std::find_if(ir_context->module()->begin(),
1960                       ir_context->module()->end(),
1961                       [function_id](const opt::Function& f) {
1962                         return f.result_id() == function_id;
1963                       });
1964 }
1965 
1966 }  // namespace fuzzerutil
1967 }  // namespace fuzz
1968 }  // namespace spvtools
1969