• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/fuzz/fuzzer_util.h"
16 
17 #include <algorithm>
18 #include <unordered_set>
19 
20 #include "source/opt/build_module.h"
21 
22 namespace spvtools {
23 namespace fuzz {
24 
25 namespace fuzzerutil {
26 namespace {
27 
28 // A utility class that uses RAII to change and restore the terminator
29 // instruction of the |block|.
30 class ChangeTerminatorRAII {
31  public:
ChangeTerminatorRAII(opt::BasicBlock * block,opt::Instruction new_terminator)32   explicit ChangeTerminatorRAII(opt::BasicBlock* block,
33                                 opt::Instruction new_terminator)
34       : block_(block), old_terminator_(std::move(*block->terminator())) {
35     *block_->terminator() = std::move(new_terminator);
36   }
37 
~ChangeTerminatorRAII()38   ~ChangeTerminatorRAII() {
39     *block_->terminator() = std::move(old_terminator_);
40   }
41 
42  private:
43   opt::BasicBlock* block_;
44   opt::Instruction old_terminator_;
45 };
46 
MaybeGetOpConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t type_id,bool is_irrelevant)47 uint32_t MaybeGetOpConstant(opt::IRContext* ir_context,
48                             const TransformationContext& transformation_context,
49                             const std::vector<uint32_t>& words,
50                             uint32_t type_id, bool is_irrelevant) {
51   for (const auto& inst : ir_context->types_values()) {
52     if (inst.opcode() == spv::Op::OpConstant && inst.type_id() == type_id &&
53         inst.GetInOperand(0).words == words &&
54         transformation_context.GetFactManager()->IdIsIrrelevant(
55             inst.result_id()) == is_irrelevant) {
56       return inst.result_id();
57     }
58   }
59 
60   return 0;
61 }
62 
63 }  // namespace
64 
65 const spvtools::MessageConsumer kSilentMessageConsumer =
66     [](spv_message_level_t, const char*, const spv_position_t&,
__anon3879f0950202(spv_message_level_t, const char*, const spv_position_t&, const char*) 67        const char*) -> void {};
68 
BuildIRContext(spv_target_env target_env,const spvtools::MessageConsumer & message_consumer,const std::vector<uint32_t> & binary_in,spv_validator_options validator_options,std::unique_ptr<spvtools::opt::IRContext> * ir_context)69 bool BuildIRContext(spv_target_env target_env,
70                     const spvtools::MessageConsumer& message_consumer,
71                     const std::vector<uint32_t>& binary_in,
72                     spv_validator_options validator_options,
73                     std::unique_ptr<spvtools::opt::IRContext>* ir_context) {
74   SpirvTools tools(target_env);
75   tools.SetMessageConsumer(message_consumer);
76   if (!tools.IsValid()) {
77     message_consumer(SPV_MSG_ERROR, nullptr, {},
78                      "Failed to create SPIRV-Tools interface; stopping.");
79     return false;
80   }
81 
82   // Initial binary should be valid.
83   if (!tools.Validate(binary_in.data(), binary_in.size(), validator_options)) {
84     message_consumer(SPV_MSG_ERROR, nullptr, {},
85                      "Initial binary is invalid; stopping.");
86     return false;
87   }
88 
89   // Build the module from the input binary.
90   auto result = BuildModule(target_env, message_consumer, binary_in.data(),
91                             binary_in.size());
92   assert(result && "IRContext must be valid");
93   *ir_context = std::move(result);
94   return true;
95 }
96 
IsFreshId(opt::IRContext * context,uint32_t id)97 bool IsFreshId(opt::IRContext* context, uint32_t id) {
98   return !context->get_def_use_mgr()->GetDef(id);
99 }
100 
UpdateModuleIdBound(opt::IRContext * context,uint32_t id)101 void UpdateModuleIdBound(opt::IRContext* context, uint32_t id) {
102   // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/2541) consider the
103   //  case where the maximum id bound is reached.
104   context->module()->SetIdBound(
105       std::max(context->module()->id_bound(), id + 1));
106 }
107 
MaybeFindBlock(opt::IRContext * context,uint32_t maybe_block_id)108 opt::BasicBlock* MaybeFindBlock(opt::IRContext* context,
109                                 uint32_t maybe_block_id) {
110   auto inst = context->get_def_use_mgr()->GetDef(maybe_block_id);
111   if (inst == nullptr) {
112     // No instruction defining this id was found.
113     return nullptr;
114   }
115   if (inst->opcode() != spv::Op::OpLabel) {
116     // The instruction defining the id is not a label, so it cannot be a block
117     // id.
118     return nullptr;
119   }
120   return context->cfg()->block(maybe_block_id);
121 }
122 
PhiIdsOkForNewEdge(opt::IRContext * context,opt::BasicBlock * bb_from,opt::BasicBlock * bb_to,const google::protobuf::RepeatedField<google::protobuf::uint32> & phi_ids)123 bool PhiIdsOkForNewEdge(
124     opt::IRContext* context, opt::BasicBlock* bb_from, opt::BasicBlock* bb_to,
125     const google::protobuf::RepeatedField<google::protobuf::uint32>& phi_ids) {
126   if (bb_from->IsSuccessor(bb_to)) {
127     // There is already an edge from |from_block| to |to_block|, so there is
128     // no need to extend OpPhi instructions.  Do not allow phi ids to be
129     // present. This might turn out to be too strict; perhaps it would be OK
130     // just to ignore the ids in this case.
131     return phi_ids.empty();
132   }
133   // The edge would add a previously non-existent edge from |from_block| to
134   // |to_block|, so we go through the given phi ids and check that they exactly
135   // match the OpPhi instructions in |to_block|.
136   uint32_t phi_index = 0;
137   // An explicit loop, rather than applying a lambda to each OpPhi in |bb_to|,
138   // makes sense here because we need to increment |phi_index| for each OpPhi
139   // instruction.
140   for (auto& inst : *bb_to) {
141     if (inst.opcode() != spv::Op::OpPhi) {
142       // The OpPhi instructions all occur at the start of the block; if we find
143       // a non-OpPhi then we have seen them all.
144       break;
145     }
146     if (phi_index == static_cast<uint32_t>(phi_ids.size())) {
147       // Not enough phi ids have been provided to account for the OpPhi
148       // instructions.
149       return false;
150     }
151     // Look for an instruction defining the next phi id.
152     opt::Instruction* phi_extension =
153         context->get_def_use_mgr()->GetDef(phi_ids[phi_index]);
154     if (!phi_extension) {
155       // The id given to extend this OpPhi does not exist.
156       return false;
157     }
158     if (phi_extension->type_id() != inst.type_id()) {
159       // The instruction given to extend this OpPhi either does not have a type
160       // or its type does not match that of the OpPhi.
161       return false;
162     }
163 
164     if (context->get_instr_block(phi_extension)) {
165       // The instruction defining the phi id has an associated block (i.e., it
166       // is not a global value).  Check whether its definition dominates the
167       // exit of |from_block|.
168       auto dominator_analysis =
169           context->GetDominatorAnalysis(bb_from->GetParent());
170       if (!dominator_analysis->Dominates(phi_extension,
171                                          bb_from->terminator())) {
172         // The given id is no good as its definition does not dominate the exit
173         // of |from_block|
174         return false;
175       }
176     }
177     phi_index++;
178   }
179   // We allow some of the ids provided for extending OpPhi instructions to be
180   // unused.  Their presence does no harm, and requiring a perfect match may
181   // make transformations less likely to cleanly apply.
182   return true;
183 }
184 
CreateUnreachableEdgeInstruction(opt::IRContext * ir_context,uint32_t bb_from_id,uint32_t bb_to_id,uint32_t bool_id)185 opt::Instruction CreateUnreachableEdgeInstruction(opt::IRContext* ir_context,
186                                                   uint32_t bb_from_id,
187                                                   uint32_t bb_to_id,
188                                                   uint32_t bool_id) {
189   const auto* bb_from = MaybeFindBlock(ir_context, bb_from_id);
190   assert(bb_from && "|bb_from_id| is invalid");
191   assert(MaybeFindBlock(ir_context, bb_to_id) && "|bb_to_id| is invalid");
192   assert(bb_from->terminator()->opcode() == spv::Op::OpBranch &&
193          "Precondition on terminator of bb_from is not satisfied");
194 
195   // Get the id of the boolean constant to be used as the condition.
196   auto condition_inst = ir_context->get_def_use_mgr()->GetDef(bool_id);
197   assert(condition_inst &&
198          (condition_inst->opcode() == spv::Op::OpConstantTrue ||
199           condition_inst->opcode() == spv::Op::OpConstantFalse) &&
200          "|bool_id| is invalid");
201 
202   auto condition_value = condition_inst->opcode() == spv::Op::OpConstantTrue;
203   auto successor_id = bb_from->terminator()->GetSingleWordInOperand(0);
204 
205   // Add the dead branch, by turning OpBranch into OpBranchConditional, and
206   // ordering the targets depending on whether the given boolean corresponds to
207   // true or false.
208   return opt::Instruction(
209       ir_context, spv::Op::OpBranchConditional, 0, 0,
210       {{SPV_OPERAND_TYPE_ID, {bool_id}},
211        {SPV_OPERAND_TYPE_ID, {condition_value ? successor_id : bb_to_id}},
212        {SPV_OPERAND_TYPE_ID, {condition_value ? bb_to_id : successor_id}}});
213 }
214 
AddUnreachableEdgeAndUpdateOpPhis(opt::IRContext * context,opt::BasicBlock * bb_from,opt::BasicBlock * bb_to,uint32_t bool_id,const google::protobuf::RepeatedField<google::protobuf::uint32> & phi_ids)215 void AddUnreachableEdgeAndUpdateOpPhis(
216     opt::IRContext* context, opt::BasicBlock* bb_from, opt::BasicBlock* bb_to,
217     uint32_t bool_id,
218     const google::protobuf::RepeatedField<google::protobuf::uint32>& phi_ids) {
219   assert(PhiIdsOkForNewEdge(context, bb_from, bb_to, phi_ids) &&
220          "Precondition on phi_ids is not satisfied");
221 
222   const bool from_to_edge_already_exists = bb_from->IsSuccessor(bb_to);
223   *bb_from->terminator() = CreateUnreachableEdgeInstruction(
224       context, bb_from->id(), bb_to->id(), bool_id);
225 
226   // Update OpPhi instructions in the target block if this branch adds a
227   // previously non-existent edge from source to target.
228   if (!from_to_edge_already_exists) {
229     uint32_t phi_index = 0;
230     for (auto& inst : *bb_to) {
231       if (inst.opcode() != spv::Op::OpPhi) {
232         break;
233       }
234       assert(phi_index < static_cast<uint32_t>(phi_ids.size()) &&
235              "There should be at least one phi id per OpPhi instruction.");
236       inst.AddOperand({SPV_OPERAND_TYPE_ID, {phi_ids[phi_index]}});
237       inst.AddOperand({SPV_OPERAND_TYPE_ID, {bb_from->id()}});
238       phi_index++;
239     }
240   }
241 }
242 
BlockIsBackEdge(opt::IRContext * context,uint32_t block_id,uint32_t loop_header_id)243 bool BlockIsBackEdge(opt::IRContext* context, uint32_t block_id,
244                      uint32_t loop_header_id) {
245   auto block = context->cfg()->block(block_id);
246   auto loop_header = context->cfg()->block(loop_header_id);
247 
248   // |block| and |loop_header| must be defined, |loop_header| must be in fact
249   // loop header and |block| must branch to it.
250   if (!(block && loop_header && loop_header->IsLoopHeader() &&
251         block->IsSuccessor(loop_header))) {
252     return false;
253   }
254 
255   // |block| must be reachable and be dominated by |loop_header|.
256   opt::DominatorAnalysis* dominator_analysis =
257       context->GetDominatorAnalysis(loop_header->GetParent());
258   return context->IsReachable(*block) &&
259          dominator_analysis->Dominates(loop_header, block);
260 }
261 
BlockIsInLoopContinueConstruct(opt::IRContext * context,uint32_t block_id,uint32_t maybe_loop_header_id)262 bool BlockIsInLoopContinueConstruct(opt::IRContext* context, uint32_t block_id,
263                                     uint32_t maybe_loop_header_id) {
264   // We deem a block to be part of a loop's continue construct if the loop's
265   // continue target dominates the block.
266   auto containing_construct_block = context->cfg()->block(maybe_loop_header_id);
267   if (containing_construct_block->IsLoopHeader()) {
268     auto continue_target = containing_construct_block->ContinueBlockId();
269     if (context->GetDominatorAnalysis(containing_construct_block->GetParent())
270             ->Dominates(continue_target, block_id)) {
271       return true;
272     }
273   }
274   return false;
275 }
276 
GetIteratorForInstruction(opt::BasicBlock * block,const opt::Instruction * inst)277 opt::BasicBlock::iterator GetIteratorForInstruction(
278     opt::BasicBlock* block, const opt::Instruction* inst) {
279   for (auto inst_it = block->begin(); inst_it != block->end(); ++inst_it) {
280     if (inst == &*inst_it) {
281       return inst_it;
282     }
283   }
284   return block->end();
285 }
286 
CanInsertOpcodeBeforeInstruction(spv::Op opcode,const opt::BasicBlock::iterator & instruction_in_block)287 bool CanInsertOpcodeBeforeInstruction(
288     spv::Op opcode, const opt::BasicBlock::iterator& instruction_in_block) {
289   if (instruction_in_block->PreviousNode() &&
290       (instruction_in_block->PreviousNode()->opcode() == spv::Op::OpLoopMerge ||
291        instruction_in_block->PreviousNode()->opcode() ==
292            spv::Op::OpSelectionMerge)) {
293     // We cannot insert directly after a merge instruction.
294     return false;
295   }
296   if (opcode != spv::Op::OpVariable &&
297       instruction_in_block->opcode() == spv::Op::OpVariable) {
298     // We cannot insert a non-OpVariable instruction directly before a
299     // variable; variables in a function must be contiguous in the entry block.
300     return false;
301   }
302   // We cannot insert a non-OpPhi instruction directly before an OpPhi, because
303   // OpPhi instructions need to be contiguous at the start of a block.
304   return opcode == spv::Op::OpPhi ||
305          instruction_in_block->opcode() != spv::Op::OpPhi;
306 }
307 
CanMakeSynonymOf(opt::IRContext * ir_context,const TransformationContext & transformation_context,const opt::Instruction & inst)308 bool CanMakeSynonymOf(opt::IRContext* ir_context,
309                       const TransformationContext& transformation_context,
310                       const opt::Instruction& inst) {
311   if (inst.opcode() == spv::Op::OpSampledImage) {
312     // The SPIR-V data rules say that only very specific instructions may
313     // may consume the result id of an OpSampledImage, and this excludes the
314     // instructions that are used for making synonyms.
315     return false;
316   }
317   if (!inst.HasResultId()) {
318     // We can only make a synonym of an instruction that generates an id.
319     return false;
320   }
321   if (transformation_context.GetFactManager()->IdIsIrrelevant(
322           inst.result_id())) {
323     // An irrelevant id can't be a synonym of anything.
324     return false;
325   }
326   if (!inst.type_id()) {
327     // We can only make a synonym of an instruction that has a type.
328     return false;
329   }
330   auto type_inst = ir_context->get_def_use_mgr()->GetDef(inst.type_id());
331   if (type_inst->opcode() == spv::Op::OpTypeVoid) {
332     // We only make synonyms of instructions that define objects, and an object
333     // cannot have void type.
334     return false;
335   }
336   if (type_inst->opcode() == spv::Op::OpTypePointer) {
337     switch (inst.opcode()) {
338       case spv::Op::OpConstantNull:
339       case spv::Op::OpUndef:
340         // We disallow making synonyms of null or undefined pointers.  This is
341         // to provide the property that if the original shader exhibited no bad
342         // pointer accesses, the transformed shader will not either.
343         return false;
344       default:
345         break;
346     }
347   }
348 
349   // We do not make synonyms of objects that have decorations: if the synonym is
350   // not decorated analogously, using the original object vs. its synonymous
351   // form may not be equivalent.
352   return ir_context->get_decoration_mgr()
353       ->GetDecorationsFor(inst.result_id(), true)
354       .empty();
355 }
356 
IsCompositeType(const opt::analysis::Type * type)357 bool IsCompositeType(const opt::analysis::Type* type) {
358   return type && (type->AsArray() || type->AsMatrix() || type->AsStruct() ||
359                   type->AsVector());
360 }
361 
RepeatedFieldToVector(const google::protobuf::RepeatedField<uint32_t> & repeated_field)362 std::vector<uint32_t> RepeatedFieldToVector(
363     const google::protobuf::RepeatedField<uint32_t>& repeated_field) {
364   std::vector<uint32_t> result;
365   for (auto i : repeated_field) {
366     result.push_back(i);
367   }
368   return result;
369 }
370 
WalkOneCompositeTypeIndex(opt::IRContext * context,uint32_t base_object_type_id,uint32_t index)371 uint32_t WalkOneCompositeTypeIndex(opt::IRContext* context,
372                                    uint32_t base_object_type_id,
373                                    uint32_t index) {
374   auto should_be_composite_type =
375       context->get_def_use_mgr()->GetDef(base_object_type_id);
376   assert(should_be_composite_type && "The type should exist.");
377   switch (should_be_composite_type->opcode()) {
378     case spv::Op::OpTypeArray: {
379       auto array_length = GetArraySize(*should_be_composite_type, context);
380       if (array_length == 0 || index >= array_length) {
381         return 0;
382       }
383       return should_be_composite_type->GetSingleWordInOperand(0);
384     }
385     case spv::Op::OpTypeMatrix:
386     case spv::Op::OpTypeVector: {
387       auto count = should_be_composite_type->GetSingleWordInOperand(1);
388       if (index >= count) {
389         return 0;
390       }
391       return should_be_composite_type->GetSingleWordInOperand(0);
392     }
393     case spv::Op::OpTypeStruct: {
394       if (index >= GetNumberOfStructMembers(*should_be_composite_type)) {
395         return 0;
396       }
397       return should_be_composite_type->GetSingleWordInOperand(index);
398     }
399     default:
400       return 0;
401   }
402 }
403 
WalkCompositeTypeIndices(opt::IRContext * context,uint32_t base_object_type_id,const google::protobuf::RepeatedField<google::protobuf::uint32> & indices)404 uint32_t WalkCompositeTypeIndices(
405     opt::IRContext* context, uint32_t base_object_type_id,
406     const google::protobuf::RepeatedField<google::protobuf::uint32>& indices) {
407   uint32_t sub_object_type_id = base_object_type_id;
408   for (auto index : indices) {
409     sub_object_type_id =
410         WalkOneCompositeTypeIndex(context, sub_object_type_id, index);
411     if (!sub_object_type_id) {
412       return 0;
413     }
414   }
415   return sub_object_type_id;
416 }
417 
GetNumberOfStructMembers(const opt::Instruction & struct_type_instruction)418 uint32_t GetNumberOfStructMembers(
419     const opt::Instruction& struct_type_instruction) {
420   assert(struct_type_instruction.opcode() == spv::Op::OpTypeStruct &&
421          "An OpTypeStruct instruction is required here.");
422   return struct_type_instruction.NumInOperands();
423 }
424 
GetArraySize(const opt::Instruction & array_type_instruction,opt::IRContext * context)425 uint32_t GetArraySize(const opt::Instruction& array_type_instruction,
426                       opt::IRContext* context) {
427   auto array_length_constant =
428       context->get_constant_mgr()
429           ->GetConstantFromInst(context->get_def_use_mgr()->GetDef(
430               array_type_instruction.GetSingleWordInOperand(1)))
431           ->AsIntConstant();
432   if (array_length_constant->words().size() != 1) {
433     return 0;
434   }
435   return array_length_constant->GetU32();
436 }
437 
GetBoundForCompositeIndex(const opt::Instruction & composite_type_inst,opt::IRContext * ir_context)438 uint32_t GetBoundForCompositeIndex(const opt::Instruction& composite_type_inst,
439                                    opt::IRContext* ir_context) {
440   switch (composite_type_inst.opcode()) {
441     case spv::Op::OpTypeArray:
442       return fuzzerutil::GetArraySize(composite_type_inst, ir_context);
443     case spv::Op::OpTypeMatrix:
444     case spv::Op::OpTypeVector:
445       return composite_type_inst.GetSingleWordInOperand(1);
446     case spv::Op::OpTypeStruct: {
447       return fuzzerutil::GetNumberOfStructMembers(composite_type_inst);
448     }
449     case spv::Op::OpTypeRuntimeArray:
450       assert(false &&
451              "GetBoundForCompositeIndex should not be invoked with an "
452              "OpTypeRuntimeArray, which does not have a static bound.");
453       return 0;
454     default:
455       assert(false && "Unknown composite type.");
456       return 0;
457   }
458 }
459 
GetMemorySemanticsForStorageClass(spv::StorageClass storage_class)460 spv::MemorySemanticsMask GetMemorySemanticsForStorageClass(
461     spv::StorageClass storage_class) {
462   switch (storage_class) {
463     case spv::StorageClass::Workgroup:
464       return spv::MemorySemanticsMask::WorkgroupMemory;
465 
466     case spv::StorageClass::StorageBuffer:
467     case spv::StorageClass::PhysicalStorageBuffer:
468       return spv::MemorySemanticsMask::UniformMemory;
469 
470     case spv::StorageClass::CrossWorkgroup:
471       return spv::MemorySemanticsMask::CrossWorkgroupMemory;
472 
473     case spv::StorageClass::AtomicCounter:
474       return spv::MemorySemanticsMask::AtomicCounterMemory;
475 
476     case spv::StorageClass::Image:
477       return spv::MemorySemanticsMask::ImageMemory;
478 
479     default:
480       return spv::MemorySemanticsMask::MaskNone;
481   }
482 }
483 
IsValid(const opt::IRContext * context,spv_validator_options validator_options,MessageConsumer consumer)484 bool IsValid(const opt::IRContext* context,
485              spv_validator_options validator_options,
486              MessageConsumer consumer) {
487   std::vector<uint32_t> binary;
488   context->module()->ToBinary(&binary, false);
489   SpirvTools tools(context->grammar().target_env());
490   tools.SetMessageConsumer(std::move(consumer));
491   return tools.Validate(binary.data(), binary.size(), validator_options);
492 }
493 
IsValidAndWellFormed(const opt::IRContext * ir_context,spv_validator_options validator_options,MessageConsumer consumer)494 bool IsValidAndWellFormed(const opt::IRContext* ir_context,
495                           spv_validator_options validator_options,
496                           MessageConsumer consumer) {
497   if (!IsValid(ir_context, validator_options, consumer)) {
498     // Expression to dump |ir_context| to /data/temp/shader.spv:
499     //    DumpShader(ir_context, "/data/temp/shader.spv")
500     consumer(SPV_MSG_INFO, nullptr, {},
501              "Module is invalid (set a breakpoint to inspect).");
502     return false;
503   }
504   // Check that all blocks in the module have appropriate parent functions.
505   for (auto& function : *ir_context->module()) {
506     for (auto& block : function) {
507       if (block.GetParent() == nullptr) {
508         std::stringstream ss;
509         ss << "Block " << block.id() << " has no parent; its parent should be "
510            << function.result_id() << " (set a breakpoint to inspect).";
511         consumer(SPV_MSG_INFO, nullptr, {}, ss.str().c_str());
512         return false;
513       }
514       if (block.GetParent() != &function) {
515         std::stringstream ss;
516         ss << "Block " << block.id() << " should have parent "
517            << function.result_id() << " but instead has parent "
518            << block.GetParent() << " (set a breakpoint to inspect).";
519         consumer(SPV_MSG_INFO, nullptr, {}, ss.str().c_str());
520         return false;
521       }
522     }
523   }
524 
525   // Check that all instructions have distinct unique ids.  We map each unique
526   // id to the first instruction it is observed to be associated with so that
527   // if we encounter a duplicate we have access to the previous instruction -
528   // this is a useful aid to debugging.
529   std::unordered_map<uint32_t, opt::Instruction*> unique_ids;
530   bool found_duplicate = false;
531   ir_context->module()->ForEachInst([&consumer, &found_duplicate, ir_context,
532                                      &unique_ids](opt::Instruction* inst) {
533     (void)ir_context;  // Only used in an assertion; keep release-mode compilers
534                        // happy.
535     assert(inst->context() == ir_context &&
536            "Instruction has wrong IR context.");
537     if (unique_ids.count(inst->unique_id()) != 0) {
538       consumer(SPV_MSG_INFO, nullptr, {},
539                "Two instructions have the same unique id (set a breakpoint to "
540                "inspect).");
541       found_duplicate = true;
542     }
543     unique_ids.insert({inst->unique_id(), inst});
544   });
545   return !found_duplicate;
546 }
547 
CloneIRContext(opt::IRContext * context)548 std::unique_ptr<opt::IRContext> CloneIRContext(opt::IRContext* context) {
549   std::vector<uint32_t> binary;
550   context->module()->ToBinary(&binary, false);
551   return BuildModule(context->grammar().target_env(), nullptr, binary.data(),
552                      binary.size());
553 }
554 
IsNonFunctionTypeId(opt::IRContext * ir_context,uint32_t id)555 bool IsNonFunctionTypeId(opt::IRContext* ir_context, uint32_t id) {
556   auto type = ir_context->get_type_mgr()->GetType(id);
557   return type && !type->AsFunction();
558 }
559 
IsMergeOrContinue(opt::IRContext * ir_context,uint32_t block_id)560 bool IsMergeOrContinue(opt::IRContext* ir_context, uint32_t block_id) {
561   bool result = false;
562   ir_context->get_def_use_mgr()->WhileEachUse(
563       block_id,
564       [&result](const opt::Instruction* use_instruction,
565                 uint32_t /*unused*/) -> bool {
566         switch (use_instruction->opcode()) {
567           case spv::Op::OpLoopMerge:
568           case spv::Op::OpSelectionMerge:
569             result = true;
570             return false;
571           default:
572             return true;
573         }
574       });
575   return result;
576 }
577 
GetLoopFromMergeBlock(opt::IRContext * ir_context,uint32_t merge_block_id)578 uint32_t GetLoopFromMergeBlock(opt::IRContext* ir_context,
579                                uint32_t merge_block_id) {
580   uint32_t result = 0;
581   ir_context->get_def_use_mgr()->WhileEachUse(
582       merge_block_id,
583       [ir_context, &result](opt::Instruction* use_instruction,
584                             uint32_t use_index) -> bool {
585         switch (use_instruction->opcode()) {
586           case spv::Op::OpLoopMerge:
587             // The merge block operand is the first operand in OpLoopMerge.
588             if (use_index == 0) {
589               result = ir_context->get_instr_block(use_instruction)->id();
590               return false;
591             }
592             return true;
593           default:
594             return true;
595         }
596       });
597   return result;
598 }
599 
FindFunctionType(opt::IRContext * ir_context,const std::vector<uint32_t> & type_ids)600 uint32_t FindFunctionType(opt::IRContext* ir_context,
601                           const std::vector<uint32_t>& type_ids) {
602   // Look through the existing types for a match.
603   for (auto& type_or_value : ir_context->types_values()) {
604     if (type_or_value.opcode() != spv::Op::OpTypeFunction) {
605       // We are only interested in function types.
606       continue;
607     }
608     if (type_or_value.NumInOperands() != type_ids.size()) {
609       // Not a match: different numbers of arguments.
610       continue;
611     }
612     // Check whether the return type and argument types match.
613     bool input_operands_match = true;
614     for (uint32_t i = 0; i < type_or_value.NumInOperands(); i++) {
615       if (type_ids[i] != type_or_value.GetSingleWordInOperand(i)) {
616         input_operands_match = false;
617         break;
618       }
619     }
620     if (input_operands_match) {
621       // Everything matches.
622       return type_or_value.result_id();
623     }
624   }
625   // No match was found.
626   return 0;
627 }
628 
GetFunctionType(opt::IRContext * context,const opt::Function * function)629 opt::Instruction* GetFunctionType(opt::IRContext* context,
630                                   const opt::Function* function) {
631   uint32_t type_id = function->DefInst().GetSingleWordInOperand(1);
632   return context->get_def_use_mgr()->GetDef(type_id);
633 }
634 
FindFunction(opt::IRContext * ir_context,uint32_t function_id)635 opt::Function* FindFunction(opt::IRContext* ir_context, uint32_t function_id) {
636   for (auto& function : *ir_context->module()) {
637     if (function.result_id() == function_id) {
638       return &function;
639     }
640   }
641   return nullptr;
642 }
643 
FunctionContainsOpKillOrUnreachable(const opt::Function & function)644 bool FunctionContainsOpKillOrUnreachable(const opt::Function& function) {
645   for (auto& block : function) {
646     if (block.terminator()->opcode() == spv::Op::OpKill ||
647         block.terminator()->opcode() == spv::Op::OpUnreachable) {
648       return true;
649     }
650   }
651   return false;
652 }
653 
FunctionIsEntryPoint(opt::IRContext * context,uint32_t function_id)654 bool FunctionIsEntryPoint(opt::IRContext* context, uint32_t function_id) {
655   for (auto& entry_point : context->module()->entry_points()) {
656     if (entry_point.GetSingleWordInOperand(1) == function_id) {
657       return true;
658     }
659   }
660   return false;
661 }
662 
IdIsAvailableAtUse(opt::IRContext * context,opt::Instruction * use_instruction,uint32_t use_input_operand_index,uint32_t id)663 bool IdIsAvailableAtUse(opt::IRContext* context,
664                         opt::Instruction* use_instruction,
665                         uint32_t use_input_operand_index, uint32_t id) {
666   assert(context->get_instr_block(use_instruction) &&
667          "|use_instruction| must be in a basic block");
668 
669   auto defining_instruction = context->get_def_use_mgr()->GetDef(id);
670   auto enclosing_function =
671       context->get_instr_block(use_instruction)->GetParent();
672   // If the id a function parameter, it needs to be associated with the
673   // function containing the use.
674   if (defining_instruction->opcode() == spv::Op::OpFunctionParameter) {
675     return InstructionIsFunctionParameter(defining_instruction,
676                                           enclosing_function);
677   }
678   if (!context->get_instr_block(id)) {
679     // The id must be at global scope.
680     return true;
681   }
682   if (defining_instruction == use_instruction) {
683     // It is not OK for a definition to use itself.
684     return false;
685   }
686   if (!context->IsReachable(*context->get_instr_block(use_instruction)) ||
687       !context->IsReachable(*context->get_instr_block(id))) {
688     // Skip unreachable blocks.
689     return false;
690   }
691   auto dominator_analysis = context->GetDominatorAnalysis(enclosing_function);
692   if (use_instruction->opcode() == spv::Op::OpPhi) {
693     // In the case where the use is an operand to OpPhi, it is actually the
694     // *parent* block associated with the operand that must be dominated by
695     // the synonym.
696     auto parent_block =
697         use_instruction->GetSingleWordInOperand(use_input_operand_index + 1);
698     return dominator_analysis->Dominates(
699         context->get_instr_block(defining_instruction)->id(), parent_block);
700   }
701   return dominator_analysis->Dominates(defining_instruction, use_instruction);
702 }
703 
IdIsAvailableBeforeInstruction(opt::IRContext * context,opt::Instruction * instruction,uint32_t id)704 bool IdIsAvailableBeforeInstruction(opt::IRContext* context,
705                                     opt::Instruction* instruction,
706                                     uint32_t id) {
707   assert(context->get_instr_block(instruction) &&
708          "|instruction| must be in a basic block");
709 
710   auto id_definition = context->get_def_use_mgr()->GetDef(id);
711   auto function_enclosing_instruction =
712       context->get_instr_block(instruction)->GetParent();
713   // If the id a function parameter, it needs to be associated with the
714   // function containing the instruction.
715   if (id_definition->opcode() == spv::Op::OpFunctionParameter) {
716     return InstructionIsFunctionParameter(id_definition,
717                                           function_enclosing_instruction);
718   }
719   if (!context->get_instr_block(id)) {
720     // The id is at global scope.
721     return true;
722   }
723   if (id_definition == instruction) {
724     // The instruction is not available right before its own definition.
725     return false;
726   }
727   const auto* dominator_analysis =
728       context->GetDominatorAnalysis(function_enclosing_instruction);
729   if (context->IsReachable(*context->get_instr_block(instruction)) &&
730       context->IsReachable(*context->get_instr_block(id)) &&
731       dominator_analysis->Dominates(id_definition, instruction)) {
732     // The id's definition dominates the instruction, and both the definition
733     // and the instruction are in reachable blocks, thus the id is available at
734     // the instruction.
735     return true;
736   }
737   if (id_definition->opcode() == spv::Op::OpVariable &&
738       function_enclosing_instruction ==
739           context->get_instr_block(id)->GetParent()) {
740     assert(!context->IsReachable(*context->get_instr_block(instruction)) &&
741            "If the instruction were in a reachable block we should already "
742            "have returned true.");
743     // The id is a variable and it is in the same function as |instruction|.
744     // This is OK despite |instruction| being unreachable.
745     return true;
746   }
747   return false;
748 }
749 
InstructionIsFunctionParameter(opt::Instruction * instruction,opt::Function * function)750 bool InstructionIsFunctionParameter(opt::Instruction* instruction,
751                                     opt::Function* function) {
752   if (instruction->opcode() != spv::Op::OpFunctionParameter) {
753     return false;
754   }
755   bool found_parameter = false;
756   function->ForEachParam(
757       [instruction, &found_parameter](opt::Instruction* param) {
758         if (param == instruction) {
759           found_parameter = true;
760         }
761       });
762   return found_parameter;
763 }
764 
GetTypeId(opt::IRContext * context,uint32_t result_id)765 uint32_t GetTypeId(opt::IRContext* context, uint32_t result_id) {
766   const auto* inst = context->get_def_use_mgr()->GetDef(result_id);
767   assert(inst && "|result_id| is invalid");
768   return inst->type_id();
769 }
770 
GetPointeeTypeIdFromPointerType(opt::Instruction * pointer_type_inst)771 uint32_t GetPointeeTypeIdFromPointerType(opt::Instruction* pointer_type_inst) {
772   assert(pointer_type_inst &&
773          pointer_type_inst->opcode() == spv::Op::OpTypePointer &&
774          "Precondition: |pointer_type_inst| must be OpTypePointer.");
775   return pointer_type_inst->GetSingleWordInOperand(1);
776 }
777 
GetPointeeTypeIdFromPointerType(opt::IRContext * context,uint32_t pointer_type_id)778 uint32_t GetPointeeTypeIdFromPointerType(opt::IRContext* context,
779                                          uint32_t pointer_type_id) {
780   return GetPointeeTypeIdFromPointerType(
781       context->get_def_use_mgr()->GetDef(pointer_type_id));
782 }
783 
GetStorageClassFromPointerType(opt::Instruction * pointer_type_inst)784 spv::StorageClass GetStorageClassFromPointerType(
785     opt::Instruction* pointer_type_inst) {
786   assert(pointer_type_inst &&
787          pointer_type_inst->opcode() == spv::Op::OpTypePointer &&
788          "Precondition: |pointer_type_inst| must be OpTypePointer.");
789   return static_cast<spv::StorageClass>(
790       pointer_type_inst->GetSingleWordInOperand(0));
791 }
792 
GetStorageClassFromPointerType(opt::IRContext * context,uint32_t pointer_type_id)793 spv::StorageClass GetStorageClassFromPointerType(opt::IRContext* context,
794                                                  uint32_t pointer_type_id) {
795   return GetStorageClassFromPointerType(
796       context->get_def_use_mgr()->GetDef(pointer_type_id));
797 }
798 
MaybeGetPointerType(opt::IRContext * context,uint32_t pointee_type_id,spv::StorageClass storage_class)799 uint32_t MaybeGetPointerType(opt::IRContext* context, uint32_t pointee_type_id,
800                              spv::StorageClass storage_class) {
801   for (auto& inst : context->types_values()) {
802     switch (inst.opcode()) {
803       case spv::Op::OpTypePointer:
804         if (spv::StorageClass(inst.GetSingleWordInOperand(0)) ==
805                 storage_class &&
806             inst.GetSingleWordInOperand(1) == pointee_type_id) {
807           return inst.result_id();
808         }
809         break;
810       default:
811         break;
812     }
813   }
814   return 0;
815 }
816 
InOperandIndexFromOperandIndex(const opt::Instruction & inst,uint32_t absolute_index)817 uint32_t InOperandIndexFromOperandIndex(const opt::Instruction& inst,
818                                         uint32_t absolute_index) {
819   // Subtract the number of non-input operands from the index
820   return absolute_index - inst.NumOperands() + inst.NumInOperands();
821 }
822 
IsNullConstantSupported(opt::IRContext * ir_context,const opt::Instruction & type_inst)823 bool IsNullConstantSupported(opt::IRContext* ir_context,
824                              const opt::Instruction& type_inst) {
825   switch (type_inst.opcode()) {
826     case spv::Op::OpTypeArray:
827     case spv::Op::OpTypeBool:
828     case spv::Op::OpTypeDeviceEvent:
829     case spv::Op::OpTypeEvent:
830     case spv::Op::OpTypeFloat:
831     case spv::Op::OpTypeInt:
832     case spv::Op::OpTypeMatrix:
833     case spv::Op::OpTypeQueue:
834     case spv::Op::OpTypeReserveId:
835     case spv::Op::OpTypeVector:
836     case spv::Op::OpTypeStruct:
837       return true;
838     case spv::Op::OpTypePointer:
839       // Null pointers are allowed if the VariablePointers capability is
840       // enabled, or if the VariablePointersStorageBuffer capability is enabled
841       // and the pointer type has StorageBuffer as its storage class.
842       if (ir_context->get_feature_mgr()->HasCapability(
843               spv::Capability::VariablePointers)) {
844         return true;
845       }
846       if (ir_context->get_feature_mgr()->HasCapability(
847               spv::Capability::VariablePointersStorageBuffer)) {
848         return spv::StorageClass(type_inst.GetSingleWordInOperand(0)) ==
849                spv::StorageClass::StorageBuffer;
850       }
851       return false;
852     default:
853       return false;
854   }
855 }
856 
GlobalVariablesMustBeDeclaredInEntryPointInterfaces(const opt::IRContext * ir_context)857 bool GlobalVariablesMustBeDeclaredInEntryPointInterfaces(
858     const opt::IRContext* ir_context) {
859   // TODO(afd): We capture the environments for which this requirement holds.
860   //  The check should be refined on demand for other target environments.
861   switch (ir_context->grammar().target_env()) {
862     case SPV_ENV_UNIVERSAL_1_0:
863     case SPV_ENV_UNIVERSAL_1_1:
864     case SPV_ENV_UNIVERSAL_1_2:
865     case SPV_ENV_UNIVERSAL_1_3:
866     case SPV_ENV_VULKAN_1_0:
867     case SPV_ENV_VULKAN_1_1:
868       return false;
869     default:
870       return true;
871   }
872 }
873 
AddVariableIdToEntryPointInterfaces(opt::IRContext * context,uint32_t id)874 void AddVariableIdToEntryPointInterfaces(opt::IRContext* context, uint32_t id) {
875   if (GlobalVariablesMustBeDeclaredInEntryPointInterfaces(context)) {
876     // Conservatively add this global to the interface of every entry point in
877     // the module.  This means that the global is available for other
878     // transformations to use.
879     //
880     // A downside of this is that the global will be in the interface even if it
881     // ends up never being used.
882     //
883     // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3111) revisit
884     //  this if a more thorough approach to entry point interfaces is taken.
885     for (auto& entry_point : context->module()->entry_points()) {
886       entry_point.AddOperand({SPV_OPERAND_TYPE_ID, {id}});
887     }
888   }
889 }
890 
AddGlobalVariable(opt::IRContext * context,uint32_t result_id,uint32_t type_id,spv::StorageClass storage_class,uint32_t initializer_id)891 opt::Instruction* AddGlobalVariable(opt::IRContext* context, uint32_t result_id,
892                                     uint32_t type_id,
893                                     spv::StorageClass storage_class,
894                                     uint32_t initializer_id) {
895   // Check various preconditions.
896   assert(result_id != 0 && "Result id can't be 0");
897 
898   assert((storage_class == spv::StorageClass::Private ||
899           storage_class == spv::StorageClass::Workgroup) &&
900          "Variable's storage class must be either Private or Workgroup");
901 
902   auto* type_inst = context->get_def_use_mgr()->GetDef(type_id);
903   (void)type_inst;  // Variable becomes unused in release mode.
904   assert(type_inst && type_inst->opcode() == spv::Op::OpTypePointer &&
905          GetStorageClassFromPointerType(type_inst) == storage_class &&
906          "Variable's type is invalid");
907 
908   if (storage_class == spv::StorageClass::Workgroup) {
909     assert(initializer_id == 0);
910   }
911 
912   if (initializer_id != 0) {
913     const auto* constant_inst =
914         context->get_def_use_mgr()->GetDef(initializer_id);
915     (void)constant_inst;  // Variable becomes unused in release mode.
916     assert(constant_inst && spvOpcodeIsConstant(constant_inst->opcode()) &&
917            GetPointeeTypeIdFromPointerType(type_inst) ==
918                constant_inst->type_id() &&
919            "Initializer is invalid");
920   }
921 
922   opt::Instruction::OperandList operands = {
923       {SPV_OPERAND_TYPE_STORAGE_CLASS, {static_cast<uint32_t>(storage_class)}}};
924 
925   if (initializer_id) {
926     operands.push_back({SPV_OPERAND_TYPE_ID, {initializer_id}});
927   }
928 
929   auto new_instruction = MakeUnique<opt::Instruction>(
930       context, spv::Op::OpVariable, type_id, result_id, std::move(operands));
931   auto result = new_instruction.get();
932   context->module()->AddGlobalValue(std::move(new_instruction));
933 
934   AddVariableIdToEntryPointInterfaces(context, result_id);
935   UpdateModuleIdBound(context, result_id);
936 
937   return result;
938 }
939 
AddLocalVariable(opt::IRContext * context,uint32_t result_id,uint32_t type_id,uint32_t function_id,uint32_t initializer_id)940 opt::Instruction* AddLocalVariable(opt::IRContext* context, uint32_t result_id,
941                                    uint32_t type_id, uint32_t function_id,
942                                    uint32_t initializer_id) {
943   // Check various preconditions.
944   assert(result_id != 0 && "Result id can't be 0");
945 
946   auto* type_inst = context->get_def_use_mgr()->GetDef(type_id);
947   (void)type_inst;  // Variable becomes unused in release mode.
948   assert(type_inst && type_inst->opcode() == spv::Op::OpTypePointer &&
949          GetStorageClassFromPointerType(type_inst) ==
950              spv::StorageClass::Function &&
951          "Variable's type is invalid");
952 
953   const auto* constant_inst =
954       context->get_def_use_mgr()->GetDef(initializer_id);
955   (void)constant_inst;  // Variable becomes unused in release mode.
956   assert(constant_inst && spvOpcodeIsConstant(constant_inst->opcode()) &&
957          GetPointeeTypeIdFromPointerType(type_inst) ==
958              constant_inst->type_id() &&
959          "Initializer is invalid");
960 
961   auto* function = FindFunction(context, function_id);
962   assert(function && "Function id is invalid");
963 
964   auto new_instruction = MakeUnique<opt::Instruction>(
965       context, spv::Op::OpVariable, type_id, result_id,
966       opt::Instruction::OperandList{{SPV_OPERAND_TYPE_STORAGE_CLASS,
967                                      {uint32_t(spv::StorageClass::Function)}},
968                                     {SPV_OPERAND_TYPE_ID, {initializer_id}}});
969   auto result = new_instruction.get();
970   function->begin()->begin()->InsertBefore(std::move(new_instruction));
971 
972   UpdateModuleIdBound(context, result_id);
973 
974   return result;
975 }
976 
// Returns true if at least one value occurs more than once in |arr|.
bool HasDuplicates(const std::vector<uint32_t>& arr) {
  std::unordered_set<uint32_t> seen_values;
  seen_values.reserve(arr.size());
  for (uint32_t value : arr) {
    // insert() reports false in .second when the value was already present.
    if (!seen_values.insert(value).second) {
      return true;
    }
  }
  return false;
}
981 
IsPermutationOfRange(const std::vector<uint32_t> & arr,uint32_t lo,uint32_t hi)982 bool IsPermutationOfRange(const std::vector<uint32_t>& arr, uint32_t lo,
983                           uint32_t hi) {
984   if (arr.empty()) {
985     return lo > hi;
986   }
987 
988   if (HasDuplicates(arr)) {
989     return false;
990   }
991 
992   auto min_max = std::minmax_element(arr.begin(), arr.end());
993   return arr.size() == hi - lo + 1 && *min_max.first == lo &&
994          *min_max.second == hi;
995 }
996 
GetParameters(opt::IRContext * ir_context,uint32_t function_id)997 std::vector<opt::Instruction*> GetParameters(opt::IRContext* ir_context,
998                                              uint32_t function_id) {
999   auto* function = FindFunction(ir_context, function_id);
1000   assert(function && "|function_id| is invalid");
1001 
1002   std::vector<opt::Instruction*> result;
1003   function->ForEachParam(
1004       [&result](opt::Instruction* inst) { result.push_back(inst); });
1005 
1006   return result;
1007 }
1008 
RemoveParameter(opt::IRContext * ir_context,uint32_t parameter_id)1009 void RemoveParameter(opt::IRContext* ir_context, uint32_t parameter_id) {
1010   auto* function = GetFunctionFromParameterId(ir_context, parameter_id);
1011   assert(function && "|parameter_id| is invalid");
1012   assert(!FunctionIsEntryPoint(ir_context, function->result_id()) &&
1013          "Can't remove parameter from an entry point function");
1014 
1015   function->RemoveParameter(parameter_id);
1016 
1017   // We've just removed parameters from the function and cleared their memory.
1018   // Make sure analyses have no dangling pointers.
1019   ir_context->InvalidateAnalysesExceptFor(
1020       opt::IRContext::Analysis::kAnalysisNone);
1021 }
1022 
GetCallers(opt::IRContext * ir_context,uint32_t function_id)1023 std::vector<opt::Instruction*> GetCallers(opt::IRContext* ir_context,
1024                                           uint32_t function_id) {
1025   assert(FindFunction(ir_context, function_id) &&
1026          "|function_id| is not a result id of a function");
1027 
1028   std::vector<opt::Instruction*> result;
1029   ir_context->get_def_use_mgr()->ForEachUser(
1030       function_id, [&result, function_id](opt::Instruction* inst) {
1031         if (inst->opcode() == spv::Op::OpFunctionCall &&
1032             inst->GetSingleWordInOperand(0) == function_id) {
1033           result.push_back(inst);
1034         }
1035       });
1036 
1037   return result;
1038 }
1039 
GetFunctionFromParameterId(opt::IRContext * ir_context,uint32_t param_id)1040 opt::Function* GetFunctionFromParameterId(opt::IRContext* ir_context,
1041                                           uint32_t param_id) {
1042   auto* param_inst = ir_context->get_def_use_mgr()->GetDef(param_id);
1043   assert(param_inst && "Parameter id is invalid");
1044 
1045   for (auto& function : *ir_context->module()) {
1046     if (InstructionIsFunctionParameter(param_inst, &function)) {
1047       return &function;
1048     }
1049   }
1050 
1051   return nullptr;
1052 }
1053 
UpdateFunctionType(opt::IRContext * ir_context,uint32_t function_id,uint32_t new_function_type_result_id,uint32_t return_type_id,const std::vector<uint32_t> & parameter_type_ids)1054 uint32_t UpdateFunctionType(opt::IRContext* ir_context, uint32_t function_id,
1055                             uint32_t new_function_type_result_id,
1056                             uint32_t return_type_id,
1057                             const std::vector<uint32_t>& parameter_type_ids) {
1058   // Check some initial constraints.
1059   assert(ir_context->get_type_mgr()->GetType(return_type_id) &&
1060          "Return type is invalid");
1061   for (auto id : parameter_type_ids) {
1062     const auto* type = ir_context->get_type_mgr()->GetType(id);
1063     (void)type;  // Make compilers happy in release mode.
1064     // Parameters can't be OpTypeVoid.
1065     assert(type && !type->AsVoid() && "Parameter has invalid type");
1066   }
1067 
1068   auto* function = FindFunction(ir_context, function_id);
1069   assert(function && "|function_id| is invalid");
1070 
1071   auto* old_function_type = GetFunctionType(ir_context, function);
1072   assert(old_function_type && "Function has invalid type");
1073 
1074   std::vector<uint32_t> operand_ids = {return_type_id};
1075   operand_ids.insert(operand_ids.end(), parameter_type_ids.begin(),
1076                      parameter_type_ids.end());
1077 
1078   // A trivial case - we change nothing.
1079   if (FindFunctionType(ir_context, operand_ids) ==
1080       old_function_type->result_id()) {
1081     return old_function_type->result_id();
1082   }
1083 
1084   if (ir_context->get_def_use_mgr()->NumUsers(old_function_type) == 1 &&
1085       FindFunctionType(ir_context, operand_ids) == 0) {
1086     // We can change |old_function_type| only if it's used once in the module
1087     // and we are certain we won't create a duplicate as a result of the change.
1088 
1089     // Update |old_function_type| in-place.
1090     opt::Instruction::OperandList operands;
1091     for (auto id : operand_ids) {
1092       operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
1093     }
1094 
1095     old_function_type->SetInOperands(std::move(operands));
1096 
1097     // |operands| may depend on result ids defined below the |old_function_type|
1098     // in the module.
1099     old_function_type->RemoveFromList();
1100     ir_context->AddType(std::unique_ptr<opt::Instruction>(old_function_type));
1101     return old_function_type->result_id();
1102   } else {
1103     // We can't modify the |old_function_type| so we have to either use an
1104     // existing one or create a new one.
1105     auto type_id = FindOrCreateFunctionType(
1106         ir_context, new_function_type_result_id, operand_ids);
1107     assert(type_id != old_function_type->result_id() &&
1108            "We should've handled this case above");
1109 
1110     function->DefInst().SetInOperand(1, {type_id});
1111 
1112     // DefUseManager hasn't been updated yet, so if the following condition is
1113     // true, then |old_function_type| will have no users when this function
1114     // returns. We might as well remove it.
1115     if (ir_context->get_def_use_mgr()->NumUsers(old_function_type) == 1) {
1116       ir_context->KillInst(old_function_type);
1117     }
1118 
1119     return type_id;
1120   }
1121 }
1122 
AddFunctionType(opt::IRContext * ir_context,uint32_t result_id,const std::vector<uint32_t> & type_ids)1123 void AddFunctionType(opt::IRContext* ir_context, uint32_t result_id,
1124                      const std::vector<uint32_t>& type_ids) {
1125   assert(result_id != 0 && "Result id can't be 0");
1126   assert(!type_ids.empty() &&
1127          "OpTypeFunction always has at least one operand - function's return "
1128          "type");
1129   assert(IsNonFunctionTypeId(ir_context, type_ids[0]) &&
1130          "Return type must not be a function");
1131 
1132   for (size_t i = 1; i < type_ids.size(); ++i) {
1133     const auto* param_type = ir_context->get_type_mgr()->GetType(type_ids[i]);
1134     (void)param_type;  // Make compiler happy in release mode.
1135     assert(param_type && !param_type->AsVoid() && !param_type->AsFunction() &&
1136            "Function parameter can't have a function or void type");
1137   }
1138 
1139   opt::Instruction::OperandList operands;
1140   operands.reserve(type_ids.size());
1141   for (auto id : type_ids) {
1142     operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
1143   }
1144 
1145   ir_context->AddType(MakeUnique<opt::Instruction>(
1146       ir_context, spv::Op::OpTypeFunction, 0, result_id, std::move(operands)));
1147 
1148   UpdateModuleIdBound(ir_context, result_id);
1149 }
1150 
FindOrCreateFunctionType(opt::IRContext * ir_context,uint32_t result_id,const std::vector<uint32_t> & type_ids)1151 uint32_t FindOrCreateFunctionType(opt::IRContext* ir_context,
1152                                   uint32_t result_id,
1153                                   const std::vector<uint32_t>& type_ids) {
1154   if (auto existing_id = FindFunctionType(ir_context, type_ids)) {
1155     return existing_id;
1156   }
1157   AddFunctionType(ir_context, result_id, type_ids);
1158   return result_id;
1159 }
1160 
MaybeGetIntegerType(opt::IRContext * ir_context,uint32_t width,bool is_signed)1161 uint32_t MaybeGetIntegerType(opt::IRContext* ir_context, uint32_t width,
1162                              bool is_signed) {
1163   opt::analysis::Integer type(width, is_signed);
1164   return ir_context->get_type_mgr()->GetId(&type);
1165 }
1166 
MaybeGetFloatType(opt::IRContext * ir_context,uint32_t width)1167 uint32_t MaybeGetFloatType(opt::IRContext* ir_context, uint32_t width) {
1168   opt::analysis::Float type(width);
1169   return ir_context->get_type_mgr()->GetId(&type);
1170 }
1171 
MaybeGetBoolType(opt::IRContext * ir_context)1172 uint32_t MaybeGetBoolType(opt::IRContext* ir_context) {
1173   opt::analysis::Bool type;
1174   return ir_context->get_type_mgr()->GetId(&type);
1175 }
1176 
MaybeGetVectorType(opt::IRContext * ir_context,uint32_t component_type_id,uint32_t element_count)1177 uint32_t MaybeGetVectorType(opt::IRContext* ir_context,
1178                             uint32_t component_type_id,
1179                             uint32_t element_count) {
1180   const auto* component_type =
1181       ir_context->get_type_mgr()->GetType(component_type_id);
1182   assert(component_type &&
1183          (component_type->AsInteger() || component_type->AsFloat() ||
1184           component_type->AsBool()) &&
1185          "|component_type_id| is invalid");
1186   assert(element_count >= 2 && element_count <= 4 &&
1187          "Precondition: component count must be in range [2, 4].");
1188   opt::analysis::Vector type(component_type, element_count);
1189   return ir_context->get_type_mgr()->GetId(&type);
1190 }
1191 
MaybeGetStructType(opt::IRContext * ir_context,const std::vector<uint32_t> & component_type_ids)1192 uint32_t MaybeGetStructType(opt::IRContext* ir_context,
1193                             const std::vector<uint32_t>& component_type_ids) {
1194   for (auto& type_or_value : ir_context->types_values()) {
1195     if (type_or_value.opcode() != spv::Op::OpTypeStruct ||
1196         type_or_value.NumInOperands() !=
1197             static_cast<uint32_t>(component_type_ids.size())) {
1198       continue;
1199     }
1200     bool all_components_match = true;
1201     for (uint32_t i = 0; i < component_type_ids.size(); i++) {
1202       if (type_or_value.GetSingleWordInOperand(i) != component_type_ids[i]) {
1203         all_components_match = false;
1204         break;
1205       }
1206     }
1207     if (all_components_match) {
1208       return type_or_value.result_id();
1209     }
1210   }
1211   return 0;
1212 }
1213 
MaybeGetVoidType(opt::IRContext * ir_context)1214 uint32_t MaybeGetVoidType(opt::IRContext* ir_context) {
1215   opt::analysis::Void type;
1216   return ir_context->get_type_mgr()->GetId(&type);
1217 }
1218 
MaybeGetZeroConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,uint32_t scalar_or_composite_type_id,bool is_irrelevant)1219 uint32_t MaybeGetZeroConstant(
1220     opt::IRContext* ir_context,
1221     const TransformationContext& transformation_context,
1222     uint32_t scalar_or_composite_type_id, bool is_irrelevant) {
1223   const auto* type_inst =
1224       ir_context->get_def_use_mgr()->GetDef(scalar_or_composite_type_id);
1225   assert(type_inst && "|scalar_or_composite_type_id| is invalid");
1226 
1227   switch (type_inst->opcode()) {
1228     case spv::Op::OpTypeBool:
1229       return MaybeGetBoolConstant(ir_context, transformation_context, false,
1230                                   is_irrelevant);
1231     case spv::Op::OpTypeFloat:
1232     case spv::Op::OpTypeInt: {
1233       const auto width = type_inst->GetSingleWordInOperand(0);
1234       std::vector<uint32_t> words = {0};
1235       if (width > 32) {
1236         words.push_back(0);
1237       }
1238 
1239       return MaybeGetScalarConstant(ir_context, transformation_context, words,
1240                                     scalar_or_composite_type_id, is_irrelevant);
1241     }
1242     case spv::Op::OpTypeStruct: {
1243       std::vector<uint32_t> component_ids;
1244       for (uint32_t i = 0; i < type_inst->NumInOperands(); ++i) {
1245         const auto component_type_id = type_inst->GetSingleWordInOperand(i);
1246 
1247         auto component_id =
1248             MaybeGetZeroConstant(ir_context, transformation_context,
1249                                  component_type_id, is_irrelevant);
1250 
1251         if (component_id == 0 && is_irrelevant) {
1252           // Irrelevant constants can use either relevant or irrelevant
1253           // constituents.
1254           component_id = MaybeGetZeroConstant(
1255               ir_context, transformation_context, component_type_id, false);
1256         }
1257 
1258         if (component_id == 0) {
1259           return 0;
1260         }
1261 
1262         component_ids.push_back(component_id);
1263       }
1264 
1265       return MaybeGetCompositeConstant(
1266           ir_context, transformation_context, component_ids,
1267           scalar_or_composite_type_id, is_irrelevant);
1268     }
1269     case spv::Op::OpTypeMatrix:
1270     case spv::Op::OpTypeVector: {
1271       const auto component_type_id = type_inst->GetSingleWordInOperand(0);
1272 
1273       auto component_id = MaybeGetZeroConstant(
1274           ir_context, transformation_context, component_type_id, is_irrelevant);
1275 
1276       if (component_id == 0 && is_irrelevant) {
1277         // Irrelevant constants can use either relevant or irrelevant
1278         // constituents.
1279         component_id = MaybeGetZeroConstant(ir_context, transformation_context,
1280                                             component_type_id, false);
1281       }
1282 
1283       if (component_id == 0) {
1284         return 0;
1285       }
1286 
1287       const auto component_count = type_inst->GetSingleWordInOperand(1);
1288       return MaybeGetCompositeConstant(
1289           ir_context, transformation_context,
1290           std::vector<uint32_t>(component_count, component_id),
1291           scalar_or_composite_type_id, is_irrelevant);
1292     }
1293     case spv::Op::OpTypeArray: {
1294       const auto component_type_id = type_inst->GetSingleWordInOperand(0);
1295 
1296       auto component_id = MaybeGetZeroConstant(
1297           ir_context, transformation_context, component_type_id, is_irrelevant);
1298 
1299       if (component_id == 0 && is_irrelevant) {
1300         // Irrelevant constants can use either relevant or irrelevant
1301         // constituents.
1302         component_id = MaybeGetZeroConstant(ir_context, transformation_context,
1303                                             component_type_id, false);
1304       }
1305 
1306       if (component_id == 0) {
1307         return 0;
1308       }
1309 
1310       return MaybeGetCompositeConstant(
1311           ir_context, transformation_context,
1312           std::vector<uint32_t>(GetArraySize(*type_inst, ir_context),
1313                                 component_id),
1314           scalar_or_composite_type_id, is_irrelevant);
1315     }
1316     default:
1317       assert(false && "Type is not supported");
1318       return 0;
1319   }
1320 }
1321 
CanCreateConstant(opt::IRContext * ir_context,uint32_t type_id)1322 bool CanCreateConstant(opt::IRContext* ir_context, uint32_t type_id) {
1323   opt::Instruction* type_instr = ir_context->get_def_use_mgr()->GetDef(type_id);
1324   assert(type_instr != nullptr && "The type must exist.");
1325   assert(spvOpcodeGeneratesType(type_instr->opcode()) &&
1326          "A type-generating opcode was expected.");
1327   switch (type_instr->opcode()) {
1328     case spv::Op::OpTypeBool:
1329     case spv::Op::OpTypeInt:
1330     case spv::Op::OpTypeFloat:
1331     case spv::Op::OpTypeMatrix:
1332     case spv::Op::OpTypeVector:
1333       return true;
1334     case spv::Op::OpTypeArray:
1335       return CanCreateConstant(ir_context,
1336                                type_instr->GetSingleWordInOperand(0));
1337     case spv::Op::OpTypeStruct:
1338       if (HasBlockOrBufferBlockDecoration(ir_context, type_id)) {
1339         return false;
1340       }
1341       for (uint32_t index = 0; index < type_instr->NumInOperands(); index++) {
1342         if (!CanCreateConstant(ir_context,
1343                                type_instr->GetSingleWordInOperand(index))) {
1344           return false;
1345         }
1346       }
1347       return true;
1348     default:
1349       return false;
1350   }
1351 }
1352 
MaybeGetScalarConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t scalar_type_id,bool is_irrelevant)1353 uint32_t MaybeGetScalarConstant(
1354     opt::IRContext* ir_context,
1355     const TransformationContext& transformation_context,
1356     const std::vector<uint32_t>& words, uint32_t scalar_type_id,
1357     bool is_irrelevant) {
1358   const auto* type = ir_context->get_type_mgr()->GetType(scalar_type_id);
1359   assert(type && "|scalar_type_id| is invalid");
1360 
1361   if (const auto* int_type = type->AsInteger()) {
1362     return MaybeGetIntegerConstant(ir_context, transformation_context, words,
1363                                    int_type->width(), int_type->IsSigned(),
1364                                    is_irrelevant);
1365   } else if (const auto* float_type = type->AsFloat()) {
1366     return MaybeGetFloatConstant(ir_context, transformation_context, words,
1367                                  float_type->width(), is_irrelevant);
1368   } else {
1369     assert(type->AsBool() && words.size() == 1 &&
1370            "|scalar_type_id| doesn't represent a scalar type");
1371     return MaybeGetBoolConstant(ir_context, transformation_context, words[0],
1372                                 is_irrelevant);
1373   }
1374 }
1375 
MaybeGetCompositeConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & component_ids,uint32_t composite_type_id,bool is_irrelevant)1376 uint32_t MaybeGetCompositeConstant(
1377     opt::IRContext* ir_context,
1378     const TransformationContext& transformation_context,
1379     const std::vector<uint32_t>& component_ids, uint32_t composite_type_id,
1380     bool is_irrelevant) {
1381   const auto* type = ir_context->get_type_mgr()->GetType(composite_type_id);
1382   (void)type;  // Make compilers happy in release mode.
1383   assert(IsCompositeType(type) && "|composite_type_id| is invalid");
1384 
1385   for (const auto& inst : ir_context->types_values()) {
1386     if (inst.opcode() == spv::Op::OpConstantComposite &&
1387         inst.type_id() == composite_type_id &&
1388         transformation_context.GetFactManager()->IdIsIrrelevant(
1389             inst.result_id()) == is_irrelevant &&
1390         inst.NumInOperands() == component_ids.size()) {
1391       bool is_match = true;
1392 
1393       for (uint32_t i = 0; i < inst.NumInOperands(); ++i) {
1394         if (inst.GetSingleWordInOperand(i) != component_ids[i]) {
1395           is_match = false;
1396           break;
1397         }
1398       }
1399 
1400       if (is_match) {
1401         return inst.result_id();
1402       }
1403     }
1404   }
1405 
1406   return 0;
1407 }
1408 
MaybeGetIntegerConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t width,bool is_signed,bool is_irrelevant)1409 uint32_t MaybeGetIntegerConstant(
1410     opt::IRContext* ir_context,
1411     const TransformationContext& transformation_context,
1412     const std::vector<uint32_t>& words, uint32_t width, bool is_signed,
1413     bool is_irrelevant) {
1414   if (auto type_id = MaybeGetIntegerType(ir_context, width, is_signed)) {
1415     return MaybeGetOpConstant(ir_context, transformation_context, words,
1416                               type_id, is_irrelevant);
1417   }
1418 
1419   return 0;
1420 }
1421 
MaybeGetIntegerConstantFromValueAndType(opt::IRContext * ir_context,uint32_t value,uint32_t int_type_id)1422 uint32_t MaybeGetIntegerConstantFromValueAndType(opt::IRContext* ir_context,
1423                                                  uint32_t value,
1424                                                  uint32_t int_type_id) {
1425   auto int_type_inst = ir_context->get_def_use_mgr()->GetDef(int_type_id);
1426 
1427   assert(int_type_inst && "The given type id must exist.");
1428 
1429   auto int_type = ir_context->get_type_mgr()
1430                       ->GetType(int_type_inst->result_id())
1431                       ->AsInteger();
1432 
1433   assert(int_type && int_type->width() == 32 &&
1434          "The given type id must correspond to an 32-bit integer type.");
1435 
1436   opt::analysis::IntConstant constant(int_type, {value});
1437 
1438   // Check that the constant exists in the module.
1439   if (!ir_context->get_constant_mgr()->FindConstant(&constant)) {
1440     return 0;
1441   }
1442 
1443   return ir_context->get_constant_mgr()
1444       ->GetDefiningInstruction(&constant)
1445       ->result_id();
1446 }
1447 
MaybeGetFloatConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t width,bool is_irrelevant)1448 uint32_t MaybeGetFloatConstant(
1449     opt::IRContext* ir_context,
1450     const TransformationContext& transformation_context,
1451     const std::vector<uint32_t>& words, uint32_t width, bool is_irrelevant) {
1452   if (auto type_id = MaybeGetFloatType(ir_context, width)) {
1453     return MaybeGetOpConstant(ir_context, transformation_context, words,
1454                               type_id, is_irrelevant);
1455   }
1456 
1457   return 0;
1458 }
1459 
MaybeGetBoolConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,bool value,bool is_irrelevant)1460 uint32_t MaybeGetBoolConstant(
1461     opt::IRContext* ir_context,
1462     const TransformationContext& transformation_context, bool value,
1463     bool is_irrelevant) {
1464   if (auto type_id = MaybeGetBoolType(ir_context)) {
1465     for (const auto& inst : ir_context->types_values()) {
1466       if (inst.opcode() ==
1467               (value ? spv::Op::OpConstantTrue : spv::Op::OpConstantFalse) &&
1468           inst.type_id() == type_id &&
1469           transformation_context.GetFactManager()->IdIsIrrelevant(
1470               inst.result_id()) == is_irrelevant) {
1471         return inst.result_id();
1472       }
1473     }
1474   }
1475 
1476   return 0;
1477 }
1478 
// Converts the low |width| bits of |value| into the 32-bit words used to
// encode an integer literal of that width.
//
// The low |width| bits are first extended to 64 bits: sign-extended when
// |is_signed| is true, zero-extended otherwise. This makes the high bits of
// the emitted words well-defined. Widths of at most 32 bits produce one word.
// Wider integers produce two words, low word first.
//
// |width| must be in [1, 64]. A width of 0 would require shifting a 64-bit
// value by 64 bits, which is undefined behaviour in C++.
std::vector<uint32_t> IntToWords(uint64_t value, uint32_t width,
                                 bool is_signed) {
  assert(width > 0 && width <= 64 &&
         "The bit width should be between 1 and 64 bits");

  // Number of high bits of |value| that lie outside the |width|-bit integer.
  const uint32_t unused_bits = 64 - width;

  if (is_signed) {
    // Move the integer's sign bit into bit 63, then shift back arithmetically
    // (interpreting the value as signed) so that the sign bit is replicated
    // into all of the unused high bits.
    value = static_cast<uint64_t>(static_cast<int64_t>(value << unused_bits) >>
                                  unused_bits);
  } else {
    // Shifting left and then logically right clears the unused high bits.
    value = (value << unused_bits) >> unused_bits;
  }

  std::vector<uint32_t> result;
  result.push_back(static_cast<uint32_t>(value));
  if (width > 32) {
    result.push_back(static_cast<uint32_t>(value >> 32));
  }
  return result;
}
1502 
TypesAreEqualUpToSign(opt::IRContext * ir_context,uint32_t type1_id,uint32_t type2_id)1503 bool TypesAreEqualUpToSign(opt::IRContext* ir_context, uint32_t type1_id,
1504                            uint32_t type2_id) {
1505   if (type1_id == type2_id) {
1506     return true;
1507   }
1508 
1509   auto type1 = ir_context->get_type_mgr()->GetType(type1_id);
1510   auto type2 = ir_context->get_type_mgr()->GetType(type2_id);
1511 
1512   // Integer scalar types must have the same width
1513   if (type1->AsInteger() && type2->AsInteger()) {
1514     return type1->AsInteger()->width() == type2->AsInteger()->width();
1515   }
1516 
1517   // Integer vector types must have the same number of components and their
1518   // component types must be integers with the same width.
1519   if (type1->AsVector() && type2->AsVector()) {
1520     auto component_type1 = type1->AsVector()->element_type()->AsInteger();
1521     auto component_type2 = type2->AsVector()->element_type()->AsInteger();
1522 
1523     // Only check the component count and width if they are integer.
1524     if (component_type1 && component_type2) {
1525       return type1->AsVector()->element_count() ==
1526                  type2->AsVector()->element_count() &&
1527              component_type1->width() == component_type2->width();
1528     }
1529   }
1530 
1531   // In all other cases, the types cannot be considered equal.
1532   return false;
1533 }
1534 
RepeatedUInt32PairToMap(const google::protobuf::RepeatedPtrField<protobufs::UInt32Pair> & data)1535 std::map<uint32_t, uint32_t> RepeatedUInt32PairToMap(
1536     const google::protobuf::RepeatedPtrField<protobufs::UInt32Pair>& data) {
1537   std::map<uint32_t, uint32_t> result;
1538 
1539   for (const auto& entry : data) {
1540     result[entry.first()] = entry.second();
1541   }
1542 
1543   return result;
1544 }
1545 
1546 google::protobuf::RepeatedPtrField<protobufs::UInt32Pair>
MapToRepeatedUInt32Pair(const std::map<uint32_t,uint32_t> & data)1547 MapToRepeatedUInt32Pair(const std::map<uint32_t, uint32_t>& data) {
1548   google::protobuf::RepeatedPtrField<protobufs::UInt32Pair> result;
1549 
1550   for (const auto& entry : data) {
1551     protobufs::UInt32Pair pair;
1552     pair.set_first(entry.first);
1553     pair.set_second(entry.second);
1554     *result.Add() = std::move(pair);
1555   }
1556 
1557   return result;
1558 }
1559 
GetLastInsertBeforeInstruction(opt::IRContext * ir_context,uint32_t block_id,spv::Op opcode)1560 opt::Instruction* GetLastInsertBeforeInstruction(opt::IRContext* ir_context,
1561                                                  uint32_t block_id,
1562                                                  spv::Op opcode) {
1563   // CFG::block uses std::map::at which throws an exception when |block_id| is
1564   // invalid. The error message is unhelpful, though. Thus, we test that
1565   // |block_id| is valid here.
1566   const auto* label_inst = ir_context->get_def_use_mgr()->GetDef(block_id);
1567   (void)label_inst;  // Make compilers happy in release mode.
1568   assert(label_inst && label_inst->opcode() == spv::Op::OpLabel &&
1569          "|block_id| is invalid");
1570 
1571   auto* block = ir_context->cfg()->block(block_id);
1572   auto it = block->rbegin();
1573   assert(it != block->rend() && "Basic block can't be empty");
1574 
1575   if (block->GetMergeInst()) {
1576     ++it;
1577     assert(it != block->rend() &&
1578            "|block| must have at least two instructions:"
1579            "terminator and a merge instruction");
1580   }
1581 
1582   return CanInsertOpcodeBeforeInstruction(opcode, &*it) ? &*it : nullptr;
1583 }
1584 
IdUseCanBeReplaced(opt::IRContext * ir_context,const TransformationContext & transformation_context,opt::Instruction * use_instruction,uint32_t use_in_operand_index)1585 bool IdUseCanBeReplaced(opt::IRContext* ir_context,
1586                         const TransformationContext& transformation_context,
1587                         opt::Instruction* use_instruction,
1588                         uint32_t use_in_operand_index) {
1589   if (spvOpcodeIsAccessChain(use_instruction->opcode()) &&
1590       use_in_operand_index > 0) {
1591     // A replacement for an irrelevant index in OpAccessChain must be clamped
1592     // first.
1593     if (transformation_context.GetFactManager()->IdIsIrrelevant(
1594             use_instruction->GetSingleWordInOperand(use_in_operand_index))) {
1595       return false;
1596     }
1597 
1598     // This is an access chain index.  If the (sub-)object being accessed by the
1599     // given index has struct type then we cannot replace the use, as it needs
1600     // to be an OpConstant.
1601 
1602     // Get the top-level composite type that is being accessed.
1603     auto object_being_accessed = ir_context->get_def_use_mgr()->GetDef(
1604         use_instruction->GetSingleWordInOperand(0));
1605     auto pointer_type =
1606         ir_context->get_type_mgr()->GetType(object_being_accessed->type_id());
1607     assert(pointer_type->AsPointer());
1608     auto composite_type_being_accessed =
1609         pointer_type->AsPointer()->pointee_type();
1610 
1611     // Now walk the access chain, tracking the type of each sub-object of the
1612     // composite that is traversed, until the index of interest is reached.
1613     for (uint32_t index_in_operand = 1; index_in_operand < use_in_operand_index;
1614          index_in_operand++) {
1615       // For vectors, matrices and arrays, getting the type of the sub-object is
1616       // trivial. For the struct case, the sub-object type is field-sensitive,
1617       // and depends on the constant index that is used.
1618       if (composite_type_being_accessed->AsVector()) {
1619         composite_type_being_accessed =
1620             composite_type_being_accessed->AsVector()->element_type();
1621       } else if (composite_type_being_accessed->AsMatrix()) {
1622         composite_type_being_accessed =
1623             composite_type_being_accessed->AsMatrix()->element_type();
1624       } else if (composite_type_being_accessed->AsArray()) {
1625         composite_type_being_accessed =
1626             composite_type_being_accessed->AsArray()->element_type();
1627       } else if (composite_type_being_accessed->AsRuntimeArray()) {
1628         composite_type_being_accessed =
1629             composite_type_being_accessed->AsRuntimeArray()->element_type();
1630       } else {
1631         assert(composite_type_being_accessed->AsStruct());
1632         auto constant_index_instruction = ir_context->get_def_use_mgr()->GetDef(
1633             use_instruction->GetSingleWordInOperand(index_in_operand));
1634         assert(constant_index_instruction->opcode() == spv::Op::OpConstant);
1635         uint32_t member_index =
1636             constant_index_instruction->GetSingleWordInOperand(0);
1637         composite_type_being_accessed =
1638             composite_type_being_accessed->AsStruct()
1639                 ->element_types()[member_index];
1640       }
1641     }
1642 
1643     // We have found the composite type being accessed by the index we are
1644     // considering replacing. If it is a struct, then we cannot do the
1645     // replacement as struct indices must be constants.
1646     if (composite_type_being_accessed->AsStruct()) {
1647       return false;
1648     }
1649   }
1650 
1651   if (use_instruction->opcode() == spv::Op::OpFunctionCall &&
1652       use_in_operand_index > 0) {
1653     // This is a function call argument.  It is not allowed to have pointer
1654     // type.
1655 
1656     // Get the definition of the function being called.
1657     auto function = ir_context->get_def_use_mgr()->GetDef(
1658         use_instruction->GetSingleWordInOperand(0));
1659     // From the function definition, get the function type.
1660     auto function_type = ir_context->get_def_use_mgr()->GetDef(
1661         function->GetSingleWordInOperand(1));
1662     // OpTypeFunction's 0-th input operand is the function return type, and the
1663     // function argument types follow. Because the arguments to OpFunctionCall
1664     // start from input operand 1, we can use |use_in_operand_index| to get the
1665     // type associated with this function argument.
1666     auto parameter_type = ir_context->get_type_mgr()->GetType(
1667         function_type->GetSingleWordInOperand(use_in_operand_index));
1668     if (parameter_type->AsPointer()) {
1669       return false;
1670     }
1671   }
1672 
1673   if (use_instruction->opcode() == spv::Op::OpImageTexelPointer &&
1674       use_in_operand_index == 2) {
1675     // The OpImageTexelPointer instruction has a Sample parameter that in some
1676     // situations must be an id for the value 0.  To guard against disrupting
1677     // that requirement, we do not replace this argument to that instruction.
1678     return false;
1679   }
1680 
1681   if (ir_context->get_feature_mgr()->HasCapability(spv::Capability::Shader)) {
1682     // With the Shader capability, memory scope and memory semantics operands
1683     // are required to be constants, so they cannot be replaced arbitrarily.
1684     switch (use_instruction->opcode()) {
1685       case spv::Op::OpAtomicLoad:
1686       case spv::Op::OpAtomicStore:
1687       case spv::Op::OpAtomicExchange:
1688       case spv::Op::OpAtomicIIncrement:
1689       case spv::Op::OpAtomicIDecrement:
1690       case spv::Op::OpAtomicIAdd:
1691       case spv::Op::OpAtomicISub:
1692       case spv::Op::OpAtomicSMin:
1693       case spv::Op::OpAtomicUMin:
1694       case spv::Op::OpAtomicSMax:
1695       case spv::Op::OpAtomicUMax:
1696       case spv::Op::OpAtomicAnd:
1697       case spv::Op::OpAtomicOr:
1698       case spv::Op::OpAtomicXor:
1699         if (use_in_operand_index == 1 || use_in_operand_index == 2) {
1700           return false;
1701         }
1702         break;
1703       case spv::Op::OpAtomicCompareExchange:
1704         if (use_in_operand_index == 1 || use_in_operand_index == 2 ||
1705             use_in_operand_index == 3) {
1706           return false;
1707         }
1708         break;
1709       case spv::Op::OpAtomicCompareExchangeWeak:
1710       case spv::Op::OpAtomicFlagTestAndSet:
1711       case spv::Op::OpAtomicFlagClear:
1712       case spv::Op::OpAtomicFAddEXT:
1713         assert(false && "Not allowed with the Shader capability.");
1714       default:
1715         break;
1716     }
1717   }
1718 
1719   return true;
1720 }
1721 
MembersHaveBuiltInDecoration(opt::IRContext * ir_context,uint32_t struct_type_id)1722 bool MembersHaveBuiltInDecoration(opt::IRContext* ir_context,
1723                                   uint32_t struct_type_id) {
1724   const auto* type_inst = ir_context->get_def_use_mgr()->GetDef(struct_type_id);
1725   assert(type_inst && type_inst->opcode() == spv::Op::OpTypeStruct &&
1726          "|struct_type_id| is not a result id of an OpTypeStruct");
1727 
1728   uint32_t builtin_count = 0;
1729   ir_context->get_def_use_mgr()->ForEachUser(
1730       type_inst,
1731       [struct_type_id, &builtin_count](const opt::Instruction* user) {
1732         if (user->opcode() == spv::Op::OpMemberDecorate &&
1733             user->GetSingleWordInOperand(0) == struct_type_id &&
1734             static_cast<spv::Decoration>(user->GetSingleWordInOperand(2)) ==
1735                 spv::Decoration::BuiltIn) {
1736           ++builtin_count;
1737         }
1738       });
1739 
1740   assert((builtin_count == 0 || builtin_count == type_inst->NumInOperands()) &&
1741          "The module is invalid: either none or all of the members of "
1742          "|struct_type_id| may be builtin");
1743 
1744   return builtin_count != 0;
1745 }
1746 
HasBlockOrBufferBlockDecoration(opt::IRContext * ir_context,uint32_t id)1747 bool HasBlockOrBufferBlockDecoration(opt::IRContext* ir_context, uint32_t id) {
1748   for (auto decoration :
1749        {spv::Decoration::Block, spv::Decoration::BufferBlock}) {
1750     if (!ir_context->get_decoration_mgr()->WhileEachDecoration(
1751             id, uint32_t(decoration),
1752             [](const opt::Instruction & /*unused*/) -> bool {
1753               return false;
1754             })) {
1755       return true;
1756     }
1757   }
1758   return false;
1759 }
1760 
SplittingBeforeInstructionSeparatesOpSampledImageDefinitionFromUse(opt::BasicBlock * block_to_split,opt::Instruction * split_before)1761 bool SplittingBeforeInstructionSeparatesOpSampledImageDefinitionFromUse(
1762     opt::BasicBlock* block_to_split, opt::Instruction* split_before) {
1763   std::set<uint32_t> sampled_image_result_ids;
1764   bool before_split = true;
1765 
1766   // Check all the instructions in the block to split.
1767   for (auto& instruction : *block_to_split) {
1768     if (&instruction == &*split_before) {
1769       before_split = false;
1770     }
1771     if (before_split) {
1772       // If the instruction comes before the split and its opcode is
1773       // OpSampledImage, record its result id.
1774       if (instruction.opcode() == spv::Op::OpSampledImage) {
1775         sampled_image_result_ids.insert(instruction.result_id());
1776       }
1777     } else {
1778       // If the instruction comes after the split, check if ids
1779       // corresponding to OpSampledImage instructions defined before the split
1780       // are used, and return true if they are.
1781       if (!instruction.WhileEachInId(
1782               [&sampled_image_result_ids](uint32_t* id) -> bool {
1783                 return !sampled_image_result_ids.count(*id);
1784               })) {
1785         return true;
1786       }
1787     }
1788   }
1789 
1790   // No usage that would be separated from the definition has been found.
1791   return false;
1792 }
1793 
InstructionHasNoSideEffects(const opt::Instruction & instruction)1794 bool InstructionHasNoSideEffects(const opt::Instruction& instruction) {
1795   switch (instruction.opcode()) {
1796     case spv::Op::OpUndef:
1797     case spv::Op::OpAccessChain:
1798     case spv::Op::OpInBoundsAccessChain:
1799     case spv::Op::OpArrayLength:
1800     case spv::Op::OpVectorExtractDynamic:
1801     case spv::Op::OpVectorInsertDynamic:
1802     case spv::Op::OpVectorShuffle:
1803     case spv::Op::OpCompositeConstruct:
1804     case spv::Op::OpCompositeExtract:
1805     case spv::Op::OpCompositeInsert:
1806     case spv::Op::OpCopyObject:
1807     case spv::Op::OpTranspose:
1808     case spv::Op::OpConvertFToU:
1809     case spv::Op::OpConvertFToS:
1810     case spv::Op::OpConvertSToF:
1811     case spv::Op::OpConvertUToF:
1812     case spv::Op::OpUConvert:
1813     case spv::Op::OpSConvert:
1814     case spv::Op::OpFConvert:
1815     case spv::Op::OpQuantizeToF16:
1816     case spv::Op::OpSatConvertSToU:
1817     case spv::Op::OpSatConvertUToS:
1818     case spv::Op::OpBitcast:
1819     case spv::Op::OpSNegate:
1820     case spv::Op::OpFNegate:
1821     case spv::Op::OpIAdd:
1822     case spv::Op::OpFAdd:
1823     case spv::Op::OpISub:
1824     case spv::Op::OpFSub:
1825     case spv::Op::OpIMul:
1826     case spv::Op::OpFMul:
1827     case spv::Op::OpUDiv:
1828     case spv::Op::OpSDiv:
1829     case spv::Op::OpFDiv:
1830     case spv::Op::OpUMod:
1831     case spv::Op::OpSRem:
1832     case spv::Op::OpSMod:
1833     case spv::Op::OpFRem:
1834     case spv::Op::OpFMod:
1835     case spv::Op::OpVectorTimesScalar:
1836     case spv::Op::OpMatrixTimesScalar:
1837     case spv::Op::OpVectorTimesMatrix:
1838     case spv::Op::OpMatrixTimesVector:
1839     case spv::Op::OpMatrixTimesMatrix:
1840     case spv::Op::OpOuterProduct:
1841     case spv::Op::OpDot:
1842     case spv::Op::OpIAddCarry:
1843     case spv::Op::OpISubBorrow:
1844     case spv::Op::OpUMulExtended:
1845     case spv::Op::OpSMulExtended:
1846     case spv::Op::OpAny:
1847     case spv::Op::OpAll:
1848     case spv::Op::OpIsNan:
1849     case spv::Op::OpIsInf:
1850     case spv::Op::OpIsFinite:
1851     case spv::Op::OpIsNormal:
1852     case spv::Op::OpSignBitSet:
1853     case spv::Op::OpLessOrGreater:
1854     case spv::Op::OpOrdered:
1855     case spv::Op::OpUnordered:
1856     case spv::Op::OpLogicalEqual:
1857     case spv::Op::OpLogicalNotEqual:
1858     case spv::Op::OpLogicalOr:
1859     case spv::Op::OpLogicalAnd:
1860     case spv::Op::OpLogicalNot:
1861     case spv::Op::OpSelect:
1862     case spv::Op::OpIEqual:
1863     case spv::Op::OpINotEqual:
1864     case spv::Op::OpUGreaterThan:
1865     case spv::Op::OpSGreaterThan:
1866     case spv::Op::OpUGreaterThanEqual:
1867     case spv::Op::OpSGreaterThanEqual:
1868     case spv::Op::OpULessThan:
1869     case spv::Op::OpSLessThan:
1870     case spv::Op::OpULessThanEqual:
1871     case spv::Op::OpSLessThanEqual:
1872     case spv::Op::OpFOrdEqual:
1873     case spv::Op::OpFUnordEqual:
1874     case spv::Op::OpFOrdNotEqual:
1875     case spv::Op::OpFUnordNotEqual:
1876     case spv::Op::OpFOrdLessThan:
1877     case spv::Op::OpFUnordLessThan:
1878     case spv::Op::OpFOrdGreaterThan:
1879     case spv::Op::OpFUnordGreaterThan:
1880     case spv::Op::OpFOrdLessThanEqual:
1881     case spv::Op::OpFUnordLessThanEqual:
1882     case spv::Op::OpFOrdGreaterThanEqual:
1883     case spv::Op::OpFUnordGreaterThanEqual:
1884     case spv::Op::OpShiftRightLogical:
1885     case spv::Op::OpShiftRightArithmetic:
1886     case spv::Op::OpShiftLeftLogical:
1887     case spv::Op::OpBitwiseOr:
1888     case spv::Op::OpBitwiseXor:
1889     case spv::Op::OpBitwiseAnd:
1890     case spv::Op::OpNot:
1891     case spv::Op::OpBitFieldInsert:
1892     case spv::Op::OpBitFieldSExtract:
1893     case spv::Op::OpBitFieldUExtract:
1894     case spv::Op::OpBitReverse:
1895     case spv::Op::OpBitCount:
1896     case spv::Op::OpCopyLogical:
1897     case spv::Op::OpPhi:
1898     case spv::Op::OpPtrEqual:
1899     case spv::Op::OpPtrNotEqual:
1900       return true;
1901     default:
1902       return false;
1903   }
1904 }
1905 
GetReachableReturnBlocks(opt::IRContext * ir_context,uint32_t function_id)1906 std::set<uint32_t> GetReachableReturnBlocks(opt::IRContext* ir_context,
1907                                             uint32_t function_id) {
1908   auto function = ir_context->GetFunction(function_id);
1909   assert(function && "The function |function_id| must exist.");
1910 
1911   std::set<uint32_t> result;
1912 
1913   ir_context->cfg()->ForEachBlockInPostOrder(function->entry().get(),
1914                                              [&result](opt::BasicBlock* block) {
1915                                                if (block->IsReturn()) {
1916                                                  result.emplace(block->id());
1917                                                }
1918                                              });
1919 
1920   return result;
1921 }
1922 
NewTerminatorPreservesDominationRules(opt::IRContext * ir_context,uint32_t block_id,opt::Instruction new_terminator)1923 bool NewTerminatorPreservesDominationRules(opt::IRContext* ir_context,
1924                                            uint32_t block_id,
1925                                            opt::Instruction new_terminator) {
1926   auto* mutated_block = MaybeFindBlock(ir_context, block_id);
1927   assert(mutated_block && "|block_id| is invalid");
1928 
1929   ChangeTerminatorRAII change_terminator_raii(mutated_block,
1930                                               std::move(new_terminator));
1931   opt::DominatorAnalysis dominator_analysis;
1932   dominator_analysis.InitializeTree(*ir_context->cfg(),
1933                                     mutated_block->GetParent());
1934 
1935   // Check that each dominator appears before each dominated block.
1936   std::unordered_map<uint32_t, size_t> positions;
1937   for (const auto& block : *mutated_block->GetParent()) {
1938     positions[block.id()] = positions.size();
1939   }
1940 
1941   std::queue<uint32_t> q({mutated_block->GetParent()->begin()->id()});
1942   std::unordered_set<uint32_t> visited;
1943   while (!q.empty()) {
1944     auto block = q.front();
1945     q.pop();
1946     visited.insert(block);
1947 
1948     auto success = ir_context->cfg()->block(block)->WhileEachSuccessorLabel(
1949         [&positions, &visited, &dominator_analysis, block, &q](uint32_t id) {
1950           if (id == block) {
1951             // Handle the case when loop header and continue target are the same
1952             // block.
1953             return true;
1954           }
1955 
1956           if (dominator_analysis.Dominates(block, id) &&
1957               positions[block] > positions[id]) {
1958             // |block| dominates |id| but appears after |id| - violates
1959             // domination rules.
1960             return false;
1961           }
1962 
1963           if (!visited.count(id)) {
1964             q.push(id);
1965           }
1966 
1967           return true;
1968         });
1969 
1970     if (!success) {
1971       return false;
1972     }
1973   }
1974 
1975   // For each instruction in the |block->GetParent()| function check whether
1976   // all its dependencies satisfy domination rules (i.e. all id operands
1977   // dominate that instruction).
1978   for (const auto& block : *mutated_block->GetParent()) {
1979     if (!ir_context->IsReachable(block)) {
1980       // If some block is not reachable then we don't need to worry about the
1981       // preservation of domination rules for its instructions.
1982       continue;
1983     }
1984 
1985     for (const auto& inst : block) {
1986       for (uint32_t i = 0; i < inst.NumInOperands();
1987            i += inst.opcode() == spv::Op::OpPhi ? 2 : 1) {
1988         const auto& operand = inst.GetInOperand(i);
1989         if (!spvIsInIdType(operand.type)) {
1990           continue;
1991         }
1992 
1993         if (MaybeFindBlock(ir_context, operand.words[0])) {
1994           // Ignore operands that refer to OpLabel instructions.
1995           continue;
1996         }
1997 
1998         const auto* dependency_block =
1999             ir_context->get_instr_block(operand.words[0]);
2000         if (!dependency_block) {
2001           // A global instruction always dominates all instructions in any
2002           // function.
2003           continue;
2004         }
2005 
2006         auto domination_target_id = inst.opcode() == spv::Op::OpPhi
2007                                         ? inst.GetSingleWordInOperand(i + 1)
2008                                         : block.id();
2009 
2010         if (!dominator_analysis.Dominates(dependency_block->id(),
2011                                           domination_target_id)) {
2012           return false;
2013         }
2014       }
2015     }
2016   }
2017 
2018   return true;
2019 }
2020 
GetFunctionIterator(opt::IRContext * ir_context,uint32_t function_id)2021 opt::Module::iterator GetFunctionIterator(opt::IRContext* ir_context,
2022                                           uint32_t function_id) {
2023   return std::find_if(ir_context->module()->begin(),
2024                       ir_context->module()->end(),
2025                       [function_id](const opt::Function& f) {
2026                         return f.result_id() == function_id;
2027                       });
2028 }
2029 
2030 // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3582): Add all
2031 //  opcodes that are agnostic to signedness of operands to function.
2032 //  This is not exhaustive yet.
IsAgnosticToSignednessOfOperand(spv::Op opcode,uint32_t use_in_operand_index)2033 bool IsAgnosticToSignednessOfOperand(spv::Op opcode,
2034                                      uint32_t use_in_operand_index) {
2035   switch (opcode) {
2036     case spv::Op::OpSNegate:
2037     case spv::Op::OpNot:
2038     case spv::Op::OpIAdd:
2039     case spv::Op::OpISub:
2040     case spv::Op::OpIMul:
2041     case spv::Op::OpSDiv:
2042     case spv::Op::OpSRem:
2043     case spv::Op::OpSMod:
2044     case spv::Op::OpShiftRightLogical:
2045     case spv::Op::OpShiftRightArithmetic:
2046     case spv::Op::OpShiftLeftLogical:
2047     case spv::Op::OpBitwiseOr:
2048     case spv::Op::OpBitwiseXor:
2049     case spv::Op::OpBitwiseAnd:
2050     case spv::Op::OpIEqual:
2051     case spv::Op::OpINotEqual:
2052     case spv::Op::OpULessThan:
2053     case spv::Op::OpSLessThan:
2054     case spv::Op::OpUGreaterThan:
2055     case spv::Op::OpSGreaterThan:
2056     case spv::Op::OpULessThanEqual:
2057     case spv::Op::OpSLessThanEqual:
2058     case spv::Op::OpUGreaterThanEqual:
2059     case spv::Op::OpSGreaterThanEqual:
2060       return true;
2061 
2062     case spv::Op::OpAtomicStore:
2063     case spv::Op::OpAtomicExchange:
2064     case spv::Op::OpAtomicIAdd:
2065     case spv::Op::OpAtomicISub:
2066     case spv::Op::OpAtomicSMin:
2067     case spv::Op::OpAtomicUMin:
2068     case spv::Op::OpAtomicSMax:
2069     case spv::Op::OpAtomicUMax:
2070     case spv::Op::OpAtomicAnd:
2071     case spv::Op::OpAtomicOr:
2072     case spv::Op::OpAtomicXor:
2073     case spv::Op::OpAtomicFAddEXT:  // Capability AtomicFloat32AddEXT,
2074       // AtomicFloat64AddEXT.
2075       assert(use_in_operand_index != 0 &&
2076              "Signedness check should not occur on a pointer operand.");
2077       return use_in_operand_index == 1 || use_in_operand_index == 2;
2078 
2079     case spv::Op::OpAtomicCompareExchange:
2080     case spv::Op::OpAtomicCompareExchangeWeak:  // Capability Kernel.
2081       assert(use_in_operand_index != 0 &&
2082              "Signedness check should not occur on a pointer operand.");
2083       return use_in_operand_index >= 1 && use_in_operand_index <= 3;
2084 
2085     case spv::Op::OpAtomicLoad:
2086     case spv::Op::OpAtomicIIncrement:
2087     case spv::Op::OpAtomicIDecrement:
2088     case spv::Op::OpAtomicFlagTestAndSet:  // Capability Kernel.
2089     case spv::Op::OpAtomicFlagClear:       // Capability Kernel.
2090       assert(use_in_operand_index != 0 &&
2091              "Signedness check should not occur on a pointer operand.");
2092       return use_in_operand_index >= 1;
2093 
2094     case spv::Op::OpAccessChain:
2095       // The signedness of indices does not matter.
2096       return use_in_operand_index > 0;
2097 
2098     default:
2099       // Conservatively assume that the id cannot be swapped in other
2100       // instructions.
2101       return false;
2102   }
2103 }
2104 
TypesAreCompatible(opt::IRContext * ir_context,spv::Op opcode,uint32_t use_in_operand_index,uint32_t type_id_1,uint32_t type_id_2)2105 bool TypesAreCompatible(opt::IRContext* ir_context, spv::Op opcode,
2106                         uint32_t use_in_operand_index, uint32_t type_id_1,
2107                         uint32_t type_id_2) {
2108   assert(ir_context->get_type_mgr()->GetType(type_id_1) &&
2109          ir_context->get_type_mgr()->GetType(type_id_2) &&
2110          "Type ids are invalid");
2111 
2112   return type_id_1 == type_id_2 ||
2113          (IsAgnosticToSignednessOfOperand(opcode, use_in_operand_index) &&
2114           fuzzerutil::TypesAreEqualUpToSign(ir_context, type_id_1, type_id_2));
2115 }
2116 
2117 }  // namespace fuzzerutil
2118 }  // namespace fuzz
2119 }  // namespace spvtools
2120