1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/fuzz/fuzzer_util.h"
16
17 #include <algorithm>
18 #include <unordered_set>
19
20 #include "source/opt/build_module.h"
21
22 namespace spvtools {
23 namespace fuzz {
24
25 namespace fuzzerutil {
26 namespace {
27
28 // A utility class that uses RAII to change and restore the terminator
29 // instruction of the |block|.
30 class ChangeTerminatorRAII {
31 public:
ChangeTerminatorRAII(opt::BasicBlock * block,opt::Instruction new_terminator)32 explicit ChangeTerminatorRAII(opt::BasicBlock* block,
33 opt::Instruction new_terminator)
34 : block_(block), old_terminator_(std::move(*block->terminator())) {
35 *block_->terminator() = std::move(new_terminator);
36 }
37
~ChangeTerminatorRAII()38 ~ChangeTerminatorRAII() {
39 *block_->terminator() = std::move(old_terminator_);
40 }
41
42 private:
43 opt::BasicBlock* block_;
44 opt::Instruction old_terminator_;
45 };
46
MaybeGetOpConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t type_id,bool is_irrelevant)47 uint32_t MaybeGetOpConstant(opt::IRContext* ir_context,
48 const TransformationContext& transformation_context,
49 const std::vector<uint32_t>& words,
50 uint32_t type_id, bool is_irrelevant) {
51 for (const auto& inst : ir_context->types_values()) {
52 if (inst.opcode() == spv::Op::OpConstant && inst.type_id() == type_id &&
53 inst.GetInOperand(0).words == words &&
54 transformation_context.GetFactManager()->IdIsIrrelevant(
55 inst.result_id()) == is_irrelevant) {
56 return inst.result_id();
57 }
58 }
59
60 return 0;
61 }
62
63 } // namespace
64
65 const spvtools::MessageConsumer kSilentMessageConsumer =
66 [](spv_message_level_t, const char*, const spv_position_t&,
__anon3879f0950202(spv_message_level_t, const char*, const spv_position_t&, const char*) 67 const char*) -> void {};
68
BuildIRContext(spv_target_env target_env,const spvtools::MessageConsumer & message_consumer,const std::vector<uint32_t> & binary_in,spv_validator_options validator_options,std::unique_ptr<spvtools::opt::IRContext> * ir_context)69 bool BuildIRContext(spv_target_env target_env,
70 const spvtools::MessageConsumer& message_consumer,
71 const std::vector<uint32_t>& binary_in,
72 spv_validator_options validator_options,
73 std::unique_ptr<spvtools::opt::IRContext>* ir_context) {
74 SpirvTools tools(target_env);
75 tools.SetMessageConsumer(message_consumer);
76 if (!tools.IsValid()) {
77 message_consumer(SPV_MSG_ERROR, nullptr, {},
78 "Failed to create SPIRV-Tools interface; stopping.");
79 return false;
80 }
81
82 // Initial binary should be valid.
83 if (!tools.Validate(binary_in.data(), binary_in.size(), validator_options)) {
84 message_consumer(SPV_MSG_ERROR, nullptr, {},
85 "Initial binary is invalid; stopping.");
86 return false;
87 }
88
89 // Build the module from the input binary.
90 auto result = BuildModule(target_env, message_consumer, binary_in.data(),
91 binary_in.size());
92 assert(result && "IRContext must be valid");
93 *ir_context = std::move(result);
94 return true;
95 }
96
IsFreshId(opt::IRContext * context,uint32_t id)97 bool IsFreshId(opt::IRContext* context, uint32_t id) {
98 return !context->get_def_use_mgr()->GetDef(id);
99 }
100
UpdateModuleIdBound(opt::IRContext * context,uint32_t id)101 void UpdateModuleIdBound(opt::IRContext* context, uint32_t id) {
102 // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/2541) consider the
103 // case where the maximum id bound is reached.
104 context->module()->SetIdBound(
105 std::max(context->module()->id_bound(), id + 1));
106 }
107
MaybeFindBlock(opt::IRContext * context,uint32_t maybe_block_id)108 opt::BasicBlock* MaybeFindBlock(opt::IRContext* context,
109 uint32_t maybe_block_id) {
110 auto inst = context->get_def_use_mgr()->GetDef(maybe_block_id);
111 if (inst == nullptr) {
112 // No instruction defining this id was found.
113 return nullptr;
114 }
115 if (inst->opcode() != spv::Op::OpLabel) {
116 // The instruction defining the id is not a label, so it cannot be a block
117 // id.
118 return nullptr;
119 }
120 return context->cfg()->block(maybe_block_id);
121 }
122
PhiIdsOkForNewEdge(opt::IRContext * context,opt::BasicBlock * bb_from,opt::BasicBlock * bb_to,const google::protobuf::RepeatedField<google::protobuf::uint32> & phi_ids)123 bool PhiIdsOkForNewEdge(
124 opt::IRContext* context, opt::BasicBlock* bb_from, opt::BasicBlock* bb_to,
125 const google::protobuf::RepeatedField<google::protobuf::uint32>& phi_ids) {
126 if (bb_from->IsSuccessor(bb_to)) {
127 // There is already an edge from |from_block| to |to_block|, so there is
128 // no need to extend OpPhi instructions. Do not allow phi ids to be
129 // present. This might turn out to be too strict; perhaps it would be OK
130 // just to ignore the ids in this case.
131 return phi_ids.empty();
132 }
133 // The edge would add a previously non-existent edge from |from_block| to
134 // |to_block|, so we go through the given phi ids and check that they exactly
135 // match the OpPhi instructions in |to_block|.
136 uint32_t phi_index = 0;
137 // An explicit loop, rather than applying a lambda to each OpPhi in |bb_to|,
138 // makes sense here because we need to increment |phi_index| for each OpPhi
139 // instruction.
140 for (auto& inst : *bb_to) {
141 if (inst.opcode() != spv::Op::OpPhi) {
142 // The OpPhi instructions all occur at the start of the block; if we find
143 // a non-OpPhi then we have seen them all.
144 break;
145 }
146 if (phi_index == static_cast<uint32_t>(phi_ids.size())) {
147 // Not enough phi ids have been provided to account for the OpPhi
148 // instructions.
149 return false;
150 }
151 // Look for an instruction defining the next phi id.
152 opt::Instruction* phi_extension =
153 context->get_def_use_mgr()->GetDef(phi_ids[phi_index]);
154 if (!phi_extension) {
155 // The id given to extend this OpPhi does not exist.
156 return false;
157 }
158 if (phi_extension->type_id() != inst.type_id()) {
159 // The instruction given to extend this OpPhi either does not have a type
160 // or its type does not match that of the OpPhi.
161 return false;
162 }
163
164 if (context->get_instr_block(phi_extension)) {
165 // The instruction defining the phi id has an associated block (i.e., it
166 // is not a global value). Check whether its definition dominates the
167 // exit of |from_block|.
168 auto dominator_analysis =
169 context->GetDominatorAnalysis(bb_from->GetParent());
170 if (!dominator_analysis->Dominates(phi_extension,
171 bb_from->terminator())) {
172 // The given id is no good as its definition does not dominate the exit
173 // of |from_block|
174 return false;
175 }
176 }
177 phi_index++;
178 }
179 // We allow some of the ids provided for extending OpPhi instructions to be
180 // unused. Their presence does no harm, and requiring a perfect match may
181 // make transformations less likely to cleanly apply.
182 return true;
183 }
184
CreateUnreachableEdgeInstruction(opt::IRContext * ir_context,uint32_t bb_from_id,uint32_t bb_to_id,uint32_t bool_id)185 opt::Instruction CreateUnreachableEdgeInstruction(opt::IRContext* ir_context,
186 uint32_t bb_from_id,
187 uint32_t bb_to_id,
188 uint32_t bool_id) {
189 const auto* bb_from = MaybeFindBlock(ir_context, bb_from_id);
190 assert(bb_from && "|bb_from_id| is invalid");
191 assert(MaybeFindBlock(ir_context, bb_to_id) && "|bb_to_id| is invalid");
192 assert(bb_from->terminator()->opcode() == spv::Op::OpBranch &&
193 "Precondition on terminator of bb_from is not satisfied");
194
195 // Get the id of the boolean constant to be used as the condition.
196 auto condition_inst = ir_context->get_def_use_mgr()->GetDef(bool_id);
197 assert(condition_inst &&
198 (condition_inst->opcode() == spv::Op::OpConstantTrue ||
199 condition_inst->opcode() == spv::Op::OpConstantFalse) &&
200 "|bool_id| is invalid");
201
202 auto condition_value = condition_inst->opcode() == spv::Op::OpConstantTrue;
203 auto successor_id = bb_from->terminator()->GetSingleWordInOperand(0);
204
205 // Add the dead branch, by turning OpBranch into OpBranchConditional, and
206 // ordering the targets depending on whether the given boolean corresponds to
207 // true or false.
208 return opt::Instruction(
209 ir_context, spv::Op::OpBranchConditional, 0, 0,
210 {{SPV_OPERAND_TYPE_ID, {bool_id}},
211 {SPV_OPERAND_TYPE_ID, {condition_value ? successor_id : bb_to_id}},
212 {SPV_OPERAND_TYPE_ID, {condition_value ? bb_to_id : successor_id}}});
213 }
214
AddUnreachableEdgeAndUpdateOpPhis(opt::IRContext * context,opt::BasicBlock * bb_from,opt::BasicBlock * bb_to,uint32_t bool_id,const google::protobuf::RepeatedField<google::protobuf::uint32> & phi_ids)215 void AddUnreachableEdgeAndUpdateOpPhis(
216 opt::IRContext* context, opt::BasicBlock* bb_from, opt::BasicBlock* bb_to,
217 uint32_t bool_id,
218 const google::protobuf::RepeatedField<google::protobuf::uint32>& phi_ids) {
219 assert(PhiIdsOkForNewEdge(context, bb_from, bb_to, phi_ids) &&
220 "Precondition on phi_ids is not satisfied");
221
222 const bool from_to_edge_already_exists = bb_from->IsSuccessor(bb_to);
223 *bb_from->terminator() = CreateUnreachableEdgeInstruction(
224 context, bb_from->id(), bb_to->id(), bool_id);
225
226 // Update OpPhi instructions in the target block if this branch adds a
227 // previously non-existent edge from source to target.
228 if (!from_to_edge_already_exists) {
229 uint32_t phi_index = 0;
230 for (auto& inst : *bb_to) {
231 if (inst.opcode() != spv::Op::OpPhi) {
232 break;
233 }
234 assert(phi_index < static_cast<uint32_t>(phi_ids.size()) &&
235 "There should be at least one phi id per OpPhi instruction.");
236 inst.AddOperand({SPV_OPERAND_TYPE_ID, {phi_ids[phi_index]}});
237 inst.AddOperand({SPV_OPERAND_TYPE_ID, {bb_from->id()}});
238 phi_index++;
239 }
240 }
241 }
242
BlockIsBackEdge(opt::IRContext * context,uint32_t block_id,uint32_t loop_header_id)243 bool BlockIsBackEdge(opt::IRContext* context, uint32_t block_id,
244 uint32_t loop_header_id) {
245 auto block = context->cfg()->block(block_id);
246 auto loop_header = context->cfg()->block(loop_header_id);
247
248 // |block| and |loop_header| must be defined, |loop_header| must be in fact
249 // loop header and |block| must branch to it.
250 if (!(block && loop_header && loop_header->IsLoopHeader() &&
251 block->IsSuccessor(loop_header))) {
252 return false;
253 }
254
255 // |block| must be reachable and be dominated by |loop_header|.
256 opt::DominatorAnalysis* dominator_analysis =
257 context->GetDominatorAnalysis(loop_header->GetParent());
258 return context->IsReachable(*block) &&
259 dominator_analysis->Dominates(loop_header, block);
260 }
261
BlockIsInLoopContinueConstruct(opt::IRContext * context,uint32_t block_id,uint32_t maybe_loop_header_id)262 bool BlockIsInLoopContinueConstruct(opt::IRContext* context, uint32_t block_id,
263 uint32_t maybe_loop_header_id) {
264 // We deem a block to be part of a loop's continue construct if the loop's
265 // continue target dominates the block.
266 auto containing_construct_block = context->cfg()->block(maybe_loop_header_id);
267 if (containing_construct_block->IsLoopHeader()) {
268 auto continue_target = containing_construct_block->ContinueBlockId();
269 if (context->GetDominatorAnalysis(containing_construct_block->GetParent())
270 ->Dominates(continue_target, block_id)) {
271 return true;
272 }
273 }
274 return false;
275 }
276
GetIteratorForInstruction(opt::BasicBlock * block,const opt::Instruction * inst)277 opt::BasicBlock::iterator GetIteratorForInstruction(
278 opt::BasicBlock* block, const opt::Instruction* inst) {
279 for (auto inst_it = block->begin(); inst_it != block->end(); ++inst_it) {
280 if (inst == &*inst_it) {
281 return inst_it;
282 }
283 }
284 return block->end();
285 }
286
CanInsertOpcodeBeforeInstruction(spv::Op opcode,const opt::BasicBlock::iterator & instruction_in_block)287 bool CanInsertOpcodeBeforeInstruction(
288 spv::Op opcode, const opt::BasicBlock::iterator& instruction_in_block) {
289 if (instruction_in_block->PreviousNode() &&
290 (instruction_in_block->PreviousNode()->opcode() == spv::Op::OpLoopMerge ||
291 instruction_in_block->PreviousNode()->opcode() ==
292 spv::Op::OpSelectionMerge)) {
293 // We cannot insert directly after a merge instruction.
294 return false;
295 }
296 if (opcode != spv::Op::OpVariable &&
297 instruction_in_block->opcode() == spv::Op::OpVariable) {
298 // We cannot insert a non-OpVariable instruction directly before a
299 // variable; variables in a function must be contiguous in the entry block.
300 return false;
301 }
302 // We cannot insert a non-OpPhi instruction directly before an OpPhi, because
303 // OpPhi instructions need to be contiguous at the start of a block.
304 return opcode == spv::Op::OpPhi ||
305 instruction_in_block->opcode() != spv::Op::OpPhi;
306 }
307
CanMakeSynonymOf(opt::IRContext * ir_context,const TransformationContext & transformation_context,const opt::Instruction & inst)308 bool CanMakeSynonymOf(opt::IRContext* ir_context,
309 const TransformationContext& transformation_context,
310 const opt::Instruction& inst) {
311 if (inst.opcode() == spv::Op::OpSampledImage) {
312 // The SPIR-V data rules say that only very specific instructions may
313 // may consume the result id of an OpSampledImage, and this excludes the
314 // instructions that are used for making synonyms.
315 return false;
316 }
317 if (!inst.HasResultId()) {
318 // We can only make a synonym of an instruction that generates an id.
319 return false;
320 }
321 if (transformation_context.GetFactManager()->IdIsIrrelevant(
322 inst.result_id())) {
323 // An irrelevant id can't be a synonym of anything.
324 return false;
325 }
326 if (!inst.type_id()) {
327 // We can only make a synonym of an instruction that has a type.
328 return false;
329 }
330 auto type_inst = ir_context->get_def_use_mgr()->GetDef(inst.type_id());
331 if (type_inst->opcode() == spv::Op::OpTypeVoid) {
332 // We only make synonyms of instructions that define objects, and an object
333 // cannot have void type.
334 return false;
335 }
336 if (type_inst->opcode() == spv::Op::OpTypePointer) {
337 switch (inst.opcode()) {
338 case spv::Op::OpConstantNull:
339 case spv::Op::OpUndef:
340 // We disallow making synonyms of null or undefined pointers. This is
341 // to provide the property that if the original shader exhibited no bad
342 // pointer accesses, the transformed shader will not either.
343 return false;
344 default:
345 break;
346 }
347 }
348
349 // We do not make synonyms of objects that have decorations: if the synonym is
350 // not decorated analogously, using the original object vs. its synonymous
351 // form may not be equivalent.
352 return ir_context->get_decoration_mgr()
353 ->GetDecorationsFor(inst.result_id(), true)
354 .empty();
355 }
356
IsCompositeType(const opt::analysis::Type * type)357 bool IsCompositeType(const opt::analysis::Type* type) {
358 return type && (type->AsArray() || type->AsMatrix() || type->AsStruct() ||
359 type->AsVector());
360 }
361
RepeatedFieldToVector(const google::protobuf::RepeatedField<uint32_t> & repeated_field)362 std::vector<uint32_t> RepeatedFieldToVector(
363 const google::protobuf::RepeatedField<uint32_t>& repeated_field) {
364 std::vector<uint32_t> result;
365 for (auto i : repeated_field) {
366 result.push_back(i);
367 }
368 return result;
369 }
370
WalkOneCompositeTypeIndex(opt::IRContext * context,uint32_t base_object_type_id,uint32_t index)371 uint32_t WalkOneCompositeTypeIndex(opt::IRContext* context,
372 uint32_t base_object_type_id,
373 uint32_t index) {
374 auto should_be_composite_type =
375 context->get_def_use_mgr()->GetDef(base_object_type_id);
376 assert(should_be_composite_type && "The type should exist.");
377 switch (should_be_composite_type->opcode()) {
378 case spv::Op::OpTypeArray: {
379 auto array_length = GetArraySize(*should_be_composite_type, context);
380 if (array_length == 0 || index >= array_length) {
381 return 0;
382 }
383 return should_be_composite_type->GetSingleWordInOperand(0);
384 }
385 case spv::Op::OpTypeMatrix:
386 case spv::Op::OpTypeVector: {
387 auto count = should_be_composite_type->GetSingleWordInOperand(1);
388 if (index >= count) {
389 return 0;
390 }
391 return should_be_composite_type->GetSingleWordInOperand(0);
392 }
393 case spv::Op::OpTypeStruct: {
394 if (index >= GetNumberOfStructMembers(*should_be_composite_type)) {
395 return 0;
396 }
397 return should_be_composite_type->GetSingleWordInOperand(index);
398 }
399 default:
400 return 0;
401 }
402 }
403
WalkCompositeTypeIndices(opt::IRContext * context,uint32_t base_object_type_id,const google::protobuf::RepeatedField<google::protobuf::uint32> & indices)404 uint32_t WalkCompositeTypeIndices(
405 opt::IRContext* context, uint32_t base_object_type_id,
406 const google::protobuf::RepeatedField<google::protobuf::uint32>& indices) {
407 uint32_t sub_object_type_id = base_object_type_id;
408 for (auto index : indices) {
409 sub_object_type_id =
410 WalkOneCompositeTypeIndex(context, sub_object_type_id, index);
411 if (!sub_object_type_id) {
412 return 0;
413 }
414 }
415 return sub_object_type_id;
416 }
417
GetNumberOfStructMembers(const opt::Instruction & struct_type_instruction)418 uint32_t GetNumberOfStructMembers(
419 const opt::Instruction& struct_type_instruction) {
420 assert(struct_type_instruction.opcode() == spv::Op::OpTypeStruct &&
421 "An OpTypeStruct instruction is required here.");
422 return struct_type_instruction.NumInOperands();
423 }
424
GetArraySize(const opt::Instruction & array_type_instruction,opt::IRContext * context)425 uint32_t GetArraySize(const opt::Instruction& array_type_instruction,
426 opt::IRContext* context) {
427 auto array_length_constant =
428 context->get_constant_mgr()
429 ->GetConstantFromInst(context->get_def_use_mgr()->GetDef(
430 array_type_instruction.GetSingleWordInOperand(1)))
431 ->AsIntConstant();
432 if (array_length_constant->words().size() != 1) {
433 return 0;
434 }
435 return array_length_constant->GetU32();
436 }
437
GetBoundForCompositeIndex(const opt::Instruction & composite_type_inst,opt::IRContext * ir_context)438 uint32_t GetBoundForCompositeIndex(const opt::Instruction& composite_type_inst,
439 opt::IRContext* ir_context) {
440 switch (composite_type_inst.opcode()) {
441 case spv::Op::OpTypeArray:
442 return fuzzerutil::GetArraySize(composite_type_inst, ir_context);
443 case spv::Op::OpTypeMatrix:
444 case spv::Op::OpTypeVector:
445 return composite_type_inst.GetSingleWordInOperand(1);
446 case spv::Op::OpTypeStruct: {
447 return fuzzerutil::GetNumberOfStructMembers(composite_type_inst);
448 }
449 case spv::Op::OpTypeRuntimeArray:
450 assert(false &&
451 "GetBoundForCompositeIndex should not be invoked with an "
452 "OpTypeRuntimeArray, which does not have a static bound.");
453 return 0;
454 default:
455 assert(false && "Unknown composite type.");
456 return 0;
457 }
458 }
459
GetMemorySemanticsForStorageClass(spv::StorageClass storage_class)460 spv::MemorySemanticsMask GetMemorySemanticsForStorageClass(
461 spv::StorageClass storage_class) {
462 switch (storage_class) {
463 case spv::StorageClass::Workgroup:
464 return spv::MemorySemanticsMask::WorkgroupMemory;
465
466 case spv::StorageClass::StorageBuffer:
467 case spv::StorageClass::PhysicalStorageBuffer:
468 return spv::MemorySemanticsMask::UniformMemory;
469
470 case spv::StorageClass::CrossWorkgroup:
471 return spv::MemorySemanticsMask::CrossWorkgroupMemory;
472
473 case spv::StorageClass::AtomicCounter:
474 return spv::MemorySemanticsMask::AtomicCounterMemory;
475
476 case spv::StorageClass::Image:
477 return spv::MemorySemanticsMask::ImageMemory;
478
479 default:
480 return spv::MemorySemanticsMask::MaskNone;
481 }
482 }
483
IsValid(const opt::IRContext * context,spv_validator_options validator_options,MessageConsumer consumer)484 bool IsValid(const opt::IRContext* context,
485 spv_validator_options validator_options,
486 MessageConsumer consumer) {
487 std::vector<uint32_t> binary;
488 context->module()->ToBinary(&binary, false);
489 SpirvTools tools(context->grammar().target_env());
490 tools.SetMessageConsumer(std::move(consumer));
491 return tools.Validate(binary.data(), binary.size(), validator_options);
492 }
493
IsValidAndWellFormed(const opt::IRContext * ir_context,spv_validator_options validator_options,MessageConsumer consumer)494 bool IsValidAndWellFormed(const opt::IRContext* ir_context,
495 spv_validator_options validator_options,
496 MessageConsumer consumer) {
497 if (!IsValid(ir_context, validator_options, consumer)) {
498 // Expression to dump |ir_context| to /data/temp/shader.spv:
499 // DumpShader(ir_context, "/data/temp/shader.spv")
500 consumer(SPV_MSG_INFO, nullptr, {},
501 "Module is invalid (set a breakpoint to inspect).");
502 return false;
503 }
504 // Check that all blocks in the module have appropriate parent functions.
505 for (auto& function : *ir_context->module()) {
506 for (auto& block : function) {
507 if (block.GetParent() == nullptr) {
508 std::stringstream ss;
509 ss << "Block " << block.id() << " has no parent; its parent should be "
510 << function.result_id() << " (set a breakpoint to inspect).";
511 consumer(SPV_MSG_INFO, nullptr, {}, ss.str().c_str());
512 return false;
513 }
514 if (block.GetParent() != &function) {
515 std::stringstream ss;
516 ss << "Block " << block.id() << " should have parent "
517 << function.result_id() << " but instead has parent "
518 << block.GetParent() << " (set a breakpoint to inspect).";
519 consumer(SPV_MSG_INFO, nullptr, {}, ss.str().c_str());
520 return false;
521 }
522 }
523 }
524
525 // Check that all instructions have distinct unique ids. We map each unique
526 // id to the first instruction it is observed to be associated with so that
527 // if we encounter a duplicate we have access to the previous instruction -
528 // this is a useful aid to debugging.
529 std::unordered_map<uint32_t, opt::Instruction*> unique_ids;
530 bool found_duplicate = false;
531 ir_context->module()->ForEachInst([&consumer, &found_duplicate, ir_context,
532 &unique_ids](opt::Instruction* inst) {
533 (void)ir_context; // Only used in an assertion; keep release-mode compilers
534 // happy.
535 assert(inst->context() == ir_context &&
536 "Instruction has wrong IR context.");
537 if (unique_ids.count(inst->unique_id()) != 0) {
538 consumer(SPV_MSG_INFO, nullptr, {},
539 "Two instructions have the same unique id (set a breakpoint to "
540 "inspect).");
541 found_duplicate = true;
542 }
543 unique_ids.insert({inst->unique_id(), inst});
544 });
545 return !found_duplicate;
546 }
547
CloneIRContext(opt::IRContext * context)548 std::unique_ptr<opt::IRContext> CloneIRContext(opt::IRContext* context) {
549 std::vector<uint32_t> binary;
550 context->module()->ToBinary(&binary, false);
551 return BuildModule(context->grammar().target_env(), nullptr, binary.data(),
552 binary.size());
553 }
554
IsNonFunctionTypeId(opt::IRContext * ir_context,uint32_t id)555 bool IsNonFunctionTypeId(opt::IRContext* ir_context, uint32_t id) {
556 auto type = ir_context->get_type_mgr()->GetType(id);
557 return type && !type->AsFunction();
558 }
559
IsMergeOrContinue(opt::IRContext * ir_context,uint32_t block_id)560 bool IsMergeOrContinue(opt::IRContext* ir_context, uint32_t block_id) {
561 bool result = false;
562 ir_context->get_def_use_mgr()->WhileEachUse(
563 block_id,
564 [&result](const opt::Instruction* use_instruction,
565 uint32_t /*unused*/) -> bool {
566 switch (use_instruction->opcode()) {
567 case spv::Op::OpLoopMerge:
568 case spv::Op::OpSelectionMerge:
569 result = true;
570 return false;
571 default:
572 return true;
573 }
574 });
575 return result;
576 }
577
GetLoopFromMergeBlock(opt::IRContext * ir_context,uint32_t merge_block_id)578 uint32_t GetLoopFromMergeBlock(opt::IRContext* ir_context,
579 uint32_t merge_block_id) {
580 uint32_t result = 0;
581 ir_context->get_def_use_mgr()->WhileEachUse(
582 merge_block_id,
583 [ir_context, &result](opt::Instruction* use_instruction,
584 uint32_t use_index) -> bool {
585 switch (use_instruction->opcode()) {
586 case spv::Op::OpLoopMerge:
587 // The merge block operand is the first operand in OpLoopMerge.
588 if (use_index == 0) {
589 result = ir_context->get_instr_block(use_instruction)->id();
590 return false;
591 }
592 return true;
593 default:
594 return true;
595 }
596 });
597 return result;
598 }
599
FindFunctionType(opt::IRContext * ir_context,const std::vector<uint32_t> & type_ids)600 uint32_t FindFunctionType(opt::IRContext* ir_context,
601 const std::vector<uint32_t>& type_ids) {
602 // Look through the existing types for a match.
603 for (auto& type_or_value : ir_context->types_values()) {
604 if (type_or_value.opcode() != spv::Op::OpTypeFunction) {
605 // We are only interested in function types.
606 continue;
607 }
608 if (type_or_value.NumInOperands() != type_ids.size()) {
609 // Not a match: different numbers of arguments.
610 continue;
611 }
612 // Check whether the return type and argument types match.
613 bool input_operands_match = true;
614 for (uint32_t i = 0; i < type_or_value.NumInOperands(); i++) {
615 if (type_ids[i] != type_or_value.GetSingleWordInOperand(i)) {
616 input_operands_match = false;
617 break;
618 }
619 }
620 if (input_operands_match) {
621 // Everything matches.
622 return type_or_value.result_id();
623 }
624 }
625 // No match was found.
626 return 0;
627 }
628
GetFunctionType(opt::IRContext * context,const opt::Function * function)629 opt::Instruction* GetFunctionType(opt::IRContext* context,
630 const opt::Function* function) {
631 uint32_t type_id = function->DefInst().GetSingleWordInOperand(1);
632 return context->get_def_use_mgr()->GetDef(type_id);
633 }
634
FindFunction(opt::IRContext * ir_context,uint32_t function_id)635 opt::Function* FindFunction(opt::IRContext* ir_context, uint32_t function_id) {
636 for (auto& function : *ir_context->module()) {
637 if (function.result_id() == function_id) {
638 return &function;
639 }
640 }
641 return nullptr;
642 }
643
FunctionContainsOpKillOrUnreachable(const opt::Function & function)644 bool FunctionContainsOpKillOrUnreachable(const opt::Function& function) {
645 for (auto& block : function) {
646 if (block.terminator()->opcode() == spv::Op::OpKill ||
647 block.terminator()->opcode() == spv::Op::OpUnreachable) {
648 return true;
649 }
650 }
651 return false;
652 }
653
FunctionIsEntryPoint(opt::IRContext * context,uint32_t function_id)654 bool FunctionIsEntryPoint(opt::IRContext* context, uint32_t function_id) {
655 for (auto& entry_point : context->module()->entry_points()) {
656 if (entry_point.GetSingleWordInOperand(1) == function_id) {
657 return true;
658 }
659 }
660 return false;
661 }
662
IdIsAvailableAtUse(opt::IRContext * context,opt::Instruction * use_instruction,uint32_t use_input_operand_index,uint32_t id)663 bool IdIsAvailableAtUse(opt::IRContext* context,
664 opt::Instruction* use_instruction,
665 uint32_t use_input_operand_index, uint32_t id) {
666 assert(context->get_instr_block(use_instruction) &&
667 "|use_instruction| must be in a basic block");
668
669 auto defining_instruction = context->get_def_use_mgr()->GetDef(id);
670 auto enclosing_function =
671 context->get_instr_block(use_instruction)->GetParent();
672 // If the id a function parameter, it needs to be associated with the
673 // function containing the use.
674 if (defining_instruction->opcode() == spv::Op::OpFunctionParameter) {
675 return InstructionIsFunctionParameter(defining_instruction,
676 enclosing_function);
677 }
678 if (!context->get_instr_block(id)) {
679 // The id must be at global scope.
680 return true;
681 }
682 if (defining_instruction == use_instruction) {
683 // It is not OK for a definition to use itself.
684 return false;
685 }
686 if (!context->IsReachable(*context->get_instr_block(use_instruction)) ||
687 !context->IsReachable(*context->get_instr_block(id))) {
688 // Skip unreachable blocks.
689 return false;
690 }
691 auto dominator_analysis = context->GetDominatorAnalysis(enclosing_function);
692 if (use_instruction->opcode() == spv::Op::OpPhi) {
693 // In the case where the use is an operand to OpPhi, it is actually the
694 // *parent* block associated with the operand that must be dominated by
695 // the synonym.
696 auto parent_block =
697 use_instruction->GetSingleWordInOperand(use_input_operand_index + 1);
698 return dominator_analysis->Dominates(
699 context->get_instr_block(defining_instruction)->id(), parent_block);
700 }
701 return dominator_analysis->Dominates(defining_instruction, use_instruction);
702 }
703
IdIsAvailableBeforeInstruction(opt::IRContext * context,opt::Instruction * instruction,uint32_t id)704 bool IdIsAvailableBeforeInstruction(opt::IRContext* context,
705 opt::Instruction* instruction,
706 uint32_t id) {
707 assert(context->get_instr_block(instruction) &&
708 "|instruction| must be in a basic block");
709
710 auto id_definition = context->get_def_use_mgr()->GetDef(id);
711 auto function_enclosing_instruction =
712 context->get_instr_block(instruction)->GetParent();
713 // If the id a function parameter, it needs to be associated with the
714 // function containing the instruction.
715 if (id_definition->opcode() == spv::Op::OpFunctionParameter) {
716 return InstructionIsFunctionParameter(id_definition,
717 function_enclosing_instruction);
718 }
719 if (!context->get_instr_block(id)) {
720 // The id is at global scope.
721 return true;
722 }
723 if (id_definition == instruction) {
724 // The instruction is not available right before its own definition.
725 return false;
726 }
727 const auto* dominator_analysis =
728 context->GetDominatorAnalysis(function_enclosing_instruction);
729 if (context->IsReachable(*context->get_instr_block(instruction)) &&
730 context->IsReachable(*context->get_instr_block(id)) &&
731 dominator_analysis->Dominates(id_definition, instruction)) {
732 // The id's definition dominates the instruction, and both the definition
733 // and the instruction are in reachable blocks, thus the id is available at
734 // the instruction.
735 return true;
736 }
737 if (id_definition->opcode() == spv::Op::OpVariable &&
738 function_enclosing_instruction ==
739 context->get_instr_block(id)->GetParent()) {
740 assert(!context->IsReachable(*context->get_instr_block(instruction)) &&
741 "If the instruction were in a reachable block we should already "
742 "have returned true.");
743 // The id is a variable and it is in the same function as |instruction|.
744 // This is OK despite |instruction| being unreachable.
745 return true;
746 }
747 return false;
748 }
749
InstructionIsFunctionParameter(opt::Instruction * instruction,opt::Function * function)750 bool InstructionIsFunctionParameter(opt::Instruction* instruction,
751 opt::Function* function) {
752 if (instruction->opcode() != spv::Op::OpFunctionParameter) {
753 return false;
754 }
755 bool found_parameter = false;
756 function->ForEachParam(
757 [instruction, &found_parameter](opt::Instruction* param) {
758 if (param == instruction) {
759 found_parameter = true;
760 }
761 });
762 return found_parameter;
763 }
764
GetTypeId(opt::IRContext * context,uint32_t result_id)765 uint32_t GetTypeId(opt::IRContext* context, uint32_t result_id) {
766 const auto* inst = context->get_def_use_mgr()->GetDef(result_id);
767 assert(inst && "|result_id| is invalid");
768 return inst->type_id();
769 }
770
GetPointeeTypeIdFromPointerType(opt::Instruction * pointer_type_inst)771 uint32_t GetPointeeTypeIdFromPointerType(opt::Instruction* pointer_type_inst) {
772 assert(pointer_type_inst &&
773 pointer_type_inst->opcode() == spv::Op::OpTypePointer &&
774 "Precondition: |pointer_type_inst| must be OpTypePointer.");
775 return pointer_type_inst->GetSingleWordInOperand(1);
776 }
777
GetPointeeTypeIdFromPointerType(opt::IRContext * context,uint32_t pointer_type_id)778 uint32_t GetPointeeTypeIdFromPointerType(opt::IRContext* context,
779 uint32_t pointer_type_id) {
780 return GetPointeeTypeIdFromPointerType(
781 context->get_def_use_mgr()->GetDef(pointer_type_id));
782 }
783
GetStorageClassFromPointerType(opt::Instruction * pointer_type_inst)784 spv::StorageClass GetStorageClassFromPointerType(
785 opt::Instruction* pointer_type_inst) {
786 assert(pointer_type_inst &&
787 pointer_type_inst->opcode() == spv::Op::OpTypePointer &&
788 "Precondition: |pointer_type_inst| must be OpTypePointer.");
789 return static_cast<spv::StorageClass>(
790 pointer_type_inst->GetSingleWordInOperand(0));
791 }
792
GetStorageClassFromPointerType(opt::IRContext * context,uint32_t pointer_type_id)793 spv::StorageClass GetStorageClassFromPointerType(opt::IRContext* context,
794 uint32_t pointer_type_id) {
795 return GetStorageClassFromPointerType(
796 context->get_def_use_mgr()->GetDef(pointer_type_id));
797 }
798
MaybeGetPointerType(opt::IRContext * context,uint32_t pointee_type_id,spv::StorageClass storage_class)799 uint32_t MaybeGetPointerType(opt::IRContext* context, uint32_t pointee_type_id,
800 spv::StorageClass storage_class) {
801 for (auto& inst : context->types_values()) {
802 switch (inst.opcode()) {
803 case spv::Op::OpTypePointer:
804 if (spv::StorageClass(inst.GetSingleWordInOperand(0)) ==
805 storage_class &&
806 inst.GetSingleWordInOperand(1) == pointee_type_id) {
807 return inst.result_id();
808 }
809 break;
810 default:
811 break;
812 }
813 }
814 return 0;
815 }
816
InOperandIndexFromOperandIndex(const opt::Instruction & inst,uint32_t absolute_index)817 uint32_t InOperandIndexFromOperandIndex(const opt::Instruction& inst,
818 uint32_t absolute_index) {
819 // Subtract the number of non-input operands from the index
820 return absolute_index - inst.NumOperands() + inst.NumInOperands();
821 }
822
IsNullConstantSupported(opt::IRContext * ir_context,const opt::Instruction & type_inst)823 bool IsNullConstantSupported(opt::IRContext* ir_context,
824 const opt::Instruction& type_inst) {
825 switch (type_inst.opcode()) {
826 case spv::Op::OpTypeArray:
827 case spv::Op::OpTypeBool:
828 case spv::Op::OpTypeDeviceEvent:
829 case spv::Op::OpTypeEvent:
830 case spv::Op::OpTypeFloat:
831 case spv::Op::OpTypeInt:
832 case spv::Op::OpTypeMatrix:
833 case spv::Op::OpTypeQueue:
834 case spv::Op::OpTypeReserveId:
835 case spv::Op::OpTypeVector:
836 case spv::Op::OpTypeStruct:
837 return true;
838 case spv::Op::OpTypePointer:
839 // Null pointers are allowed if the VariablePointers capability is
840 // enabled, or if the VariablePointersStorageBuffer capability is enabled
841 // and the pointer type has StorageBuffer as its storage class.
842 if (ir_context->get_feature_mgr()->HasCapability(
843 spv::Capability::VariablePointers)) {
844 return true;
845 }
846 if (ir_context->get_feature_mgr()->HasCapability(
847 spv::Capability::VariablePointersStorageBuffer)) {
848 return spv::StorageClass(type_inst.GetSingleWordInOperand(0)) ==
849 spv::StorageClass::StorageBuffer;
850 }
851 return false;
852 default:
853 return false;
854 }
855 }
856
GlobalVariablesMustBeDeclaredInEntryPointInterfaces(const opt::IRContext * ir_context)857 bool GlobalVariablesMustBeDeclaredInEntryPointInterfaces(
858 const opt::IRContext* ir_context) {
859 // TODO(afd): We capture the environments for which this requirement holds.
860 // The check should be refined on demand for other target environments.
861 switch (ir_context->grammar().target_env()) {
862 case SPV_ENV_UNIVERSAL_1_0:
863 case SPV_ENV_UNIVERSAL_1_1:
864 case SPV_ENV_UNIVERSAL_1_2:
865 case SPV_ENV_UNIVERSAL_1_3:
866 case SPV_ENV_VULKAN_1_0:
867 case SPV_ENV_VULKAN_1_1:
868 return false;
869 default:
870 return true;
871 }
872 }
873
AddVariableIdToEntryPointInterfaces(opt::IRContext * context,uint32_t id)874 void AddVariableIdToEntryPointInterfaces(opt::IRContext* context, uint32_t id) {
875 if (GlobalVariablesMustBeDeclaredInEntryPointInterfaces(context)) {
876 // Conservatively add this global to the interface of every entry point in
877 // the module. This means that the global is available for other
878 // transformations to use.
879 //
880 // A downside of this is that the global will be in the interface even if it
881 // ends up never being used.
882 //
883 // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3111) revisit
884 // this if a more thorough approach to entry point interfaces is taken.
885 for (auto& entry_point : context->module()->entry_points()) {
886 entry_point.AddOperand({SPV_OPERAND_TYPE_ID, {id}});
887 }
888 }
889 }
890
AddGlobalVariable(opt::IRContext * context,uint32_t result_id,uint32_t type_id,spv::StorageClass storage_class,uint32_t initializer_id)891 opt::Instruction* AddGlobalVariable(opt::IRContext* context, uint32_t result_id,
892 uint32_t type_id,
893 spv::StorageClass storage_class,
894 uint32_t initializer_id) {
895 // Check various preconditions.
896 assert(result_id != 0 && "Result id can't be 0");
897
898 assert((storage_class == spv::StorageClass::Private ||
899 storage_class == spv::StorageClass::Workgroup) &&
900 "Variable's storage class must be either Private or Workgroup");
901
902 auto* type_inst = context->get_def_use_mgr()->GetDef(type_id);
903 (void)type_inst; // Variable becomes unused in release mode.
904 assert(type_inst && type_inst->opcode() == spv::Op::OpTypePointer &&
905 GetStorageClassFromPointerType(type_inst) == storage_class &&
906 "Variable's type is invalid");
907
908 if (storage_class == spv::StorageClass::Workgroup) {
909 assert(initializer_id == 0);
910 }
911
912 if (initializer_id != 0) {
913 const auto* constant_inst =
914 context->get_def_use_mgr()->GetDef(initializer_id);
915 (void)constant_inst; // Variable becomes unused in release mode.
916 assert(constant_inst && spvOpcodeIsConstant(constant_inst->opcode()) &&
917 GetPointeeTypeIdFromPointerType(type_inst) ==
918 constant_inst->type_id() &&
919 "Initializer is invalid");
920 }
921
922 opt::Instruction::OperandList operands = {
923 {SPV_OPERAND_TYPE_STORAGE_CLASS, {static_cast<uint32_t>(storage_class)}}};
924
925 if (initializer_id) {
926 operands.push_back({SPV_OPERAND_TYPE_ID, {initializer_id}});
927 }
928
929 auto new_instruction = MakeUnique<opt::Instruction>(
930 context, spv::Op::OpVariable, type_id, result_id, std::move(operands));
931 auto result = new_instruction.get();
932 context->module()->AddGlobalValue(std::move(new_instruction));
933
934 AddVariableIdToEntryPointInterfaces(context, result_id);
935 UpdateModuleIdBound(context, result_id);
936
937 return result;
938 }
939
AddLocalVariable(opt::IRContext * context,uint32_t result_id,uint32_t type_id,uint32_t function_id,uint32_t initializer_id)940 opt::Instruction* AddLocalVariable(opt::IRContext* context, uint32_t result_id,
941 uint32_t type_id, uint32_t function_id,
942 uint32_t initializer_id) {
943 // Check various preconditions.
944 assert(result_id != 0 && "Result id can't be 0");
945
946 auto* type_inst = context->get_def_use_mgr()->GetDef(type_id);
947 (void)type_inst; // Variable becomes unused in release mode.
948 assert(type_inst && type_inst->opcode() == spv::Op::OpTypePointer &&
949 GetStorageClassFromPointerType(type_inst) ==
950 spv::StorageClass::Function &&
951 "Variable's type is invalid");
952
953 const auto* constant_inst =
954 context->get_def_use_mgr()->GetDef(initializer_id);
955 (void)constant_inst; // Variable becomes unused in release mode.
956 assert(constant_inst && spvOpcodeIsConstant(constant_inst->opcode()) &&
957 GetPointeeTypeIdFromPointerType(type_inst) ==
958 constant_inst->type_id() &&
959 "Initializer is invalid");
960
961 auto* function = FindFunction(context, function_id);
962 assert(function && "Function id is invalid");
963
964 auto new_instruction = MakeUnique<opt::Instruction>(
965 context, spv::Op::OpVariable, type_id, result_id,
966 opt::Instruction::OperandList{{SPV_OPERAND_TYPE_STORAGE_CLASS,
967 {uint32_t(spv::StorageClass::Function)}},
968 {SPV_OPERAND_TYPE_ID, {initializer_id}}});
969 auto result = new_instruction.get();
970 function->begin()->begin()->InsertBefore(std::move(new_instruction));
971
972 UpdateModuleIdBound(context, result_id);
973
974 return result;
975 }
976
// Returns true if some value occurs more than once in |arr|.
bool HasDuplicates(const std::vector<uint32_t>& arr) {
  std::unordered_set<uint32_t> seen_values;
  for (const uint32_t value : arr) {
    // insert() reports false when |value| was already present.
    if (!seen_values.insert(value).second) {
      return true;
    }
  }
  return false;
}
981
IsPermutationOfRange(const std::vector<uint32_t> & arr,uint32_t lo,uint32_t hi)982 bool IsPermutationOfRange(const std::vector<uint32_t>& arr, uint32_t lo,
983 uint32_t hi) {
984 if (arr.empty()) {
985 return lo > hi;
986 }
987
988 if (HasDuplicates(arr)) {
989 return false;
990 }
991
992 auto min_max = std::minmax_element(arr.begin(), arr.end());
993 return arr.size() == hi - lo + 1 && *min_max.first == lo &&
994 *min_max.second == hi;
995 }
996
GetParameters(opt::IRContext * ir_context,uint32_t function_id)997 std::vector<opt::Instruction*> GetParameters(opt::IRContext* ir_context,
998 uint32_t function_id) {
999 auto* function = FindFunction(ir_context, function_id);
1000 assert(function && "|function_id| is invalid");
1001
1002 std::vector<opt::Instruction*> result;
1003 function->ForEachParam(
1004 [&result](opt::Instruction* inst) { result.push_back(inst); });
1005
1006 return result;
1007 }
1008
RemoveParameter(opt::IRContext * ir_context,uint32_t parameter_id)1009 void RemoveParameter(opt::IRContext* ir_context, uint32_t parameter_id) {
1010 auto* function = GetFunctionFromParameterId(ir_context, parameter_id);
1011 assert(function && "|parameter_id| is invalid");
1012 assert(!FunctionIsEntryPoint(ir_context, function->result_id()) &&
1013 "Can't remove parameter from an entry point function");
1014
1015 function->RemoveParameter(parameter_id);
1016
1017 // We've just removed parameters from the function and cleared their memory.
1018 // Make sure analyses have no dangling pointers.
1019 ir_context->InvalidateAnalysesExceptFor(
1020 opt::IRContext::Analysis::kAnalysisNone);
1021 }
1022
GetCallers(opt::IRContext * ir_context,uint32_t function_id)1023 std::vector<opt::Instruction*> GetCallers(opt::IRContext* ir_context,
1024 uint32_t function_id) {
1025 assert(FindFunction(ir_context, function_id) &&
1026 "|function_id| is not a result id of a function");
1027
1028 std::vector<opt::Instruction*> result;
1029 ir_context->get_def_use_mgr()->ForEachUser(
1030 function_id, [&result, function_id](opt::Instruction* inst) {
1031 if (inst->opcode() == spv::Op::OpFunctionCall &&
1032 inst->GetSingleWordInOperand(0) == function_id) {
1033 result.push_back(inst);
1034 }
1035 });
1036
1037 return result;
1038 }
1039
GetFunctionFromParameterId(opt::IRContext * ir_context,uint32_t param_id)1040 opt::Function* GetFunctionFromParameterId(opt::IRContext* ir_context,
1041 uint32_t param_id) {
1042 auto* param_inst = ir_context->get_def_use_mgr()->GetDef(param_id);
1043 assert(param_inst && "Parameter id is invalid");
1044
1045 for (auto& function : *ir_context->module()) {
1046 if (InstructionIsFunctionParameter(param_inst, &function)) {
1047 return &function;
1048 }
1049 }
1050
1051 return nullptr;
1052 }
1053
UpdateFunctionType(opt::IRContext * ir_context,uint32_t function_id,uint32_t new_function_type_result_id,uint32_t return_type_id,const std::vector<uint32_t> & parameter_type_ids)1054 uint32_t UpdateFunctionType(opt::IRContext* ir_context, uint32_t function_id,
1055 uint32_t new_function_type_result_id,
1056 uint32_t return_type_id,
1057 const std::vector<uint32_t>& parameter_type_ids) {
1058 // Check some initial constraints.
1059 assert(ir_context->get_type_mgr()->GetType(return_type_id) &&
1060 "Return type is invalid");
1061 for (auto id : parameter_type_ids) {
1062 const auto* type = ir_context->get_type_mgr()->GetType(id);
1063 (void)type; // Make compilers happy in release mode.
1064 // Parameters can't be OpTypeVoid.
1065 assert(type && !type->AsVoid() && "Parameter has invalid type");
1066 }
1067
1068 auto* function = FindFunction(ir_context, function_id);
1069 assert(function && "|function_id| is invalid");
1070
1071 auto* old_function_type = GetFunctionType(ir_context, function);
1072 assert(old_function_type && "Function has invalid type");
1073
1074 std::vector<uint32_t> operand_ids = {return_type_id};
1075 operand_ids.insert(operand_ids.end(), parameter_type_ids.begin(),
1076 parameter_type_ids.end());
1077
1078 // A trivial case - we change nothing.
1079 if (FindFunctionType(ir_context, operand_ids) ==
1080 old_function_type->result_id()) {
1081 return old_function_type->result_id();
1082 }
1083
1084 if (ir_context->get_def_use_mgr()->NumUsers(old_function_type) == 1 &&
1085 FindFunctionType(ir_context, operand_ids) == 0) {
1086 // We can change |old_function_type| only if it's used once in the module
1087 // and we are certain we won't create a duplicate as a result of the change.
1088
1089 // Update |old_function_type| in-place.
1090 opt::Instruction::OperandList operands;
1091 for (auto id : operand_ids) {
1092 operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
1093 }
1094
1095 old_function_type->SetInOperands(std::move(operands));
1096
1097 // |operands| may depend on result ids defined below the |old_function_type|
1098 // in the module.
1099 old_function_type->RemoveFromList();
1100 ir_context->AddType(std::unique_ptr<opt::Instruction>(old_function_type));
1101 return old_function_type->result_id();
1102 } else {
1103 // We can't modify the |old_function_type| so we have to either use an
1104 // existing one or create a new one.
1105 auto type_id = FindOrCreateFunctionType(
1106 ir_context, new_function_type_result_id, operand_ids);
1107 assert(type_id != old_function_type->result_id() &&
1108 "We should've handled this case above");
1109
1110 function->DefInst().SetInOperand(1, {type_id});
1111
1112 // DefUseManager hasn't been updated yet, so if the following condition is
1113 // true, then |old_function_type| will have no users when this function
1114 // returns. We might as well remove it.
1115 if (ir_context->get_def_use_mgr()->NumUsers(old_function_type) == 1) {
1116 ir_context->KillInst(old_function_type);
1117 }
1118
1119 return type_id;
1120 }
1121 }
1122
AddFunctionType(opt::IRContext * ir_context,uint32_t result_id,const std::vector<uint32_t> & type_ids)1123 void AddFunctionType(opt::IRContext* ir_context, uint32_t result_id,
1124 const std::vector<uint32_t>& type_ids) {
1125 assert(result_id != 0 && "Result id can't be 0");
1126 assert(!type_ids.empty() &&
1127 "OpTypeFunction always has at least one operand - function's return "
1128 "type");
1129 assert(IsNonFunctionTypeId(ir_context, type_ids[0]) &&
1130 "Return type must not be a function");
1131
1132 for (size_t i = 1; i < type_ids.size(); ++i) {
1133 const auto* param_type = ir_context->get_type_mgr()->GetType(type_ids[i]);
1134 (void)param_type; // Make compiler happy in release mode.
1135 assert(param_type && !param_type->AsVoid() && !param_type->AsFunction() &&
1136 "Function parameter can't have a function or void type");
1137 }
1138
1139 opt::Instruction::OperandList operands;
1140 operands.reserve(type_ids.size());
1141 for (auto id : type_ids) {
1142 operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
1143 }
1144
1145 ir_context->AddType(MakeUnique<opt::Instruction>(
1146 ir_context, spv::Op::OpTypeFunction, 0, result_id, std::move(operands)));
1147
1148 UpdateModuleIdBound(ir_context, result_id);
1149 }
1150
FindOrCreateFunctionType(opt::IRContext * ir_context,uint32_t result_id,const std::vector<uint32_t> & type_ids)1151 uint32_t FindOrCreateFunctionType(opt::IRContext* ir_context,
1152 uint32_t result_id,
1153 const std::vector<uint32_t>& type_ids) {
1154 if (auto existing_id = FindFunctionType(ir_context, type_ids)) {
1155 return existing_id;
1156 }
1157 AddFunctionType(ir_context, result_id, type_ids);
1158 return result_id;
1159 }
1160
MaybeGetIntegerType(opt::IRContext * ir_context,uint32_t width,bool is_signed)1161 uint32_t MaybeGetIntegerType(opt::IRContext* ir_context, uint32_t width,
1162 bool is_signed) {
1163 opt::analysis::Integer type(width, is_signed);
1164 return ir_context->get_type_mgr()->GetId(&type);
1165 }
1166
MaybeGetFloatType(opt::IRContext * ir_context,uint32_t width)1167 uint32_t MaybeGetFloatType(opt::IRContext* ir_context, uint32_t width) {
1168 opt::analysis::Float type(width);
1169 return ir_context->get_type_mgr()->GetId(&type);
1170 }
1171
MaybeGetBoolType(opt::IRContext * ir_context)1172 uint32_t MaybeGetBoolType(opt::IRContext* ir_context) {
1173 opt::analysis::Bool type;
1174 return ir_context->get_type_mgr()->GetId(&type);
1175 }
1176
MaybeGetVectorType(opt::IRContext * ir_context,uint32_t component_type_id,uint32_t element_count)1177 uint32_t MaybeGetVectorType(opt::IRContext* ir_context,
1178 uint32_t component_type_id,
1179 uint32_t element_count) {
1180 const auto* component_type =
1181 ir_context->get_type_mgr()->GetType(component_type_id);
1182 assert(component_type &&
1183 (component_type->AsInteger() || component_type->AsFloat() ||
1184 component_type->AsBool()) &&
1185 "|component_type_id| is invalid");
1186 assert(element_count >= 2 && element_count <= 4 &&
1187 "Precondition: component count must be in range [2, 4].");
1188 opt::analysis::Vector type(component_type, element_count);
1189 return ir_context->get_type_mgr()->GetId(&type);
1190 }
1191
MaybeGetStructType(opt::IRContext * ir_context,const std::vector<uint32_t> & component_type_ids)1192 uint32_t MaybeGetStructType(opt::IRContext* ir_context,
1193 const std::vector<uint32_t>& component_type_ids) {
1194 for (auto& type_or_value : ir_context->types_values()) {
1195 if (type_or_value.opcode() != spv::Op::OpTypeStruct ||
1196 type_or_value.NumInOperands() !=
1197 static_cast<uint32_t>(component_type_ids.size())) {
1198 continue;
1199 }
1200 bool all_components_match = true;
1201 for (uint32_t i = 0; i < component_type_ids.size(); i++) {
1202 if (type_or_value.GetSingleWordInOperand(i) != component_type_ids[i]) {
1203 all_components_match = false;
1204 break;
1205 }
1206 }
1207 if (all_components_match) {
1208 return type_or_value.result_id();
1209 }
1210 }
1211 return 0;
1212 }
1213
MaybeGetVoidType(opt::IRContext * ir_context)1214 uint32_t MaybeGetVoidType(opt::IRContext* ir_context) {
1215 opt::analysis::Void type;
1216 return ir_context->get_type_mgr()->GetId(&type);
1217 }
1218
MaybeGetZeroConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,uint32_t scalar_or_composite_type_id,bool is_irrelevant)1219 uint32_t MaybeGetZeroConstant(
1220 opt::IRContext* ir_context,
1221 const TransformationContext& transformation_context,
1222 uint32_t scalar_or_composite_type_id, bool is_irrelevant) {
1223 const auto* type_inst =
1224 ir_context->get_def_use_mgr()->GetDef(scalar_or_composite_type_id);
1225 assert(type_inst && "|scalar_or_composite_type_id| is invalid");
1226
1227 switch (type_inst->opcode()) {
1228 case spv::Op::OpTypeBool:
1229 return MaybeGetBoolConstant(ir_context, transformation_context, false,
1230 is_irrelevant);
1231 case spv::Op::OpTypeFloat:
1232 case spv::Op::OpTypeInt: {
1233 const auto width = type_inst->GetSingleWordInOperand(0);
1234 std::vector<uint32_t> words = {0};
1235 if (width > 32) {
1236 words.push_back(0);
1237 }
1238
1239 return MaybeGetScalarConstant(ir_context, transformation_context, words,
1240 scalar_or_composite_type_id, is_irrelevant);
1241 }
1242 case spv::Op::OpTypeStruct: {
1243 std::vector<uint32_t> component_ids;
1244 for (uint32_t i = 0; i < type_inst->NumInOperands(); ++i) {
1245 const auto component_type_id = type_inst->GetSingleWordInOperand(i);
1246
1247 auto component_id =
1248 MaybeGetZeroConstant(ir_context, transformation_context,
1249 component_type_id, is_irrelevant);
1250
1251 if (component_id == 0 && is_irrelevant) {
1252 // Irrelevant constants can use either relevant or irrelevant
1253 // constituents.
1254 component_id = MaybeGetZeroConstant(
1255 ir_context, transformation_context, component_type_id, false);
1256 }
1257
1258 if (component_id == 0) {
1259 return 0;
1260 }
1261
1262 component_ids.push_back(component_id);
1263 }
1264
1265 return MaybeGetCompositeConstant(
1266 ir_context, transformation_context, component_ids,
1267 scalar_or_composite_type_id, is_irrelevant);
1268 }
1269 case spv::Op::OpTypeMatrix:
1270 case spv::Op::OpTypeVector: {
1271 const auto component_type_id = type_inst->GetSingleWordInOperand(0);
1272
1273 auto component_id = MaybeGetZeroConstant(
1274 ir_context, transformation_context, component_type_id, is_irrelevant);
1275
1276 if (component_id == 0 && is_irrelevant) {
1277 // Irrelevant constants can use either relevant or irrelevant
1278 // constituents.
1279 component_id = MaybeGetZeroConstant(ir_context, transformation_context,
1280 component_type_id, false);
1281 }
1282
1283 if (component_id == 0) {
1284 return 0;
1285 }
1286
1287 const auto component_count = type_inst->GetSingleWordInOperand(1);
1288 return MaybeGetCompositeConstant(
1289 ir_context, transformation_context,
1290 std::vector<uint32_t>(component_count, component_id),
1291 scalar_or_composite_type_id, is_irrelevant);
1292 }
1293 case spv::Op::OpTypeArray: {
1294 const auto component_type_id = type_inst->GetSingleWordInOperand(0);
1295
1296 auto component_id = MaybeGetZeroConstant(
1297 ir_context, transformation_context, component_type_id, is_irrelevant);
1298
1299 if (component_id == 0 && is_irrelevant) {
1300 // Irrelevant constants can use either relevant or irrelevant
1301 // constituents.
1302 component_id = MaybeGetZeroConstant(ir_context, transformation_context,
1303 component_type_id, false);
1304 }
1305
1306 if (component_id == 0) {
1307 return 0;
1308 }
1309
1310 return MaybeGetCompositeConstant(
1311 ir_context, transformation_context,
1312 std::vector<uint32_t>(GetArraySize(*type_inst, ir_context),
1313 component_id),
1314 scalar_or_composite_type_id, is_irrelevant);
1315 }
1316 default:
1317 assert(false && "Type is not supported");
1318 return 0;
1319 }
1320 }
1321
CanCreateConstant(opt::IRContext * ir_context,uint32_t type_id)1322 bool CanCreateConstant(opt::IRContext* ir_context, uint32_t type_id) {
1323 opt::Instruction* type_instr = ir_context->get_def_use_mgr()->GetDef(type_id);
1324 assert(type_instr != nullptr && "The type must exist.");
1325 assert(spvOpcodeGeneratesType(type_instr->opcode()) &&
1326 "A type-generating opcode was expected.");
1327 switch (type_instr->opcode()) {
1328 case spv::Op::OpTypeBool:
1329 case spv::Op::OpTypeInt:
1330 case spv::Op::OpTypeFloat:
1331 case spv::Op::OpTypeMatrix:
1332 case spv::Op::OpTypeVector:
1333 return true;
1334 case spv::Op::OpTypeArray:
1335 return CanCreateConstant(ir_context,
1336 type_instr->GetSingleWordInOperand(0));
1337 case spv::Op::OpTypeStruct:
1338 if (HasBlockOrBufferBlockDecoration(ir_context, type_id)) {
1339 return false;
1340 }
1341 for (uint32_t index = 0; index < type_instr->NumInOperands(); index++) {
1342 if (!CanCreateConstant(ir_context,
1343 type_instr->GetSingleWordInOperand(index))) {
1344 return false;
1345 }
1346 }
1347 return true;
1348 default:
1349 return false;
1350 }
1351 }
1352
MaybeGetScalarConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t scalar_type_id,bool is_irrelevant)1353 uint32_t MaybeGetScalarConstant(
1354 opt::IRContext* ir_context,
1355 const TransformationContext& transformation_context,
1356 const std::vector<uint32_t>& words, uint32_t scalar_type_id,
1357 bool is_irrelevant) {
1358 const auto* type = ir_context->get_type_mgr()->GetType(scalar_type_id);
1359 assert(type && "|scalar_type_id| is invalid");
1360
1361 if (const auto* int_type = type->AsInteger()) {
1362 return MaybeGetIntegerConstant(ir_context, transformation_context, words,
1363 int_type->width(), int_type->IsSigned(),
1364 is_irrelevant);
1365 } else if (const auto* float_type = type->AsFloat()) {
1366 return MaybeGetFloatConstant(ir_context, transformation_context, words,
1367 float_type->width(), is_irrelevant);
1368 } else {
1369 assert(type->AsBool() && words.size() == 1 &&
1370 "|scalar_type_id| doesn't represent a scalar type");
1371 return MaybeGetBoolConstant(ir_context, transformation_context, words[0],
1372 is_irrelevant);
1373 }
1374 }
1375
MaybeGetCompositeConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & component_ids,uint32_t composite_type_id,bool is_irrelevant)1376 uint32_t MaybeGetCompositeConstant(
1377 opt::IRContext* ir_context,
1378 const TransformationContext& transformation_context,
1379 const std::vector<uint32_t>& component_ids, uint32_t composite_type_id,
1380 bool is_irrelevant) {
1381 const auto* type = ir_context->get_type_mgr()->GetType(composite_type_id);
1382 (void)type; // Make compilers happy in release mode.
1383 assert(IsCompositeType(type) && "|composite_type_id| is invalid");
1384
1385 for (const auto& inst : ir_context->types_values()) {
1386 if (inst.opcode() == spv::Op::OpConstantComposite &&
1387 inst.type_id() == composite_type_id &&
1388 transformation_context.GetFactManager()->IdIsIrrelevant(
1389 inst.result_id()) == is_irrelevant &&
1390 inst.NumInOperands() == component_ids.size()) {
1391 bool is_match = true;
1392
1393 for (uint32_t i = 0; i < inst.NumInOperands(); ++i) {
1394 if (inst.GetSingleWordInOperand(i) != component_ids[i]) {
1395 is_match = false;
1396 break;
1397 }
1398 }
1399
1400 if (is_match) {
1401 return inst.result_id();
1402 }
1403 }
1404 }
1405
1406 return 0;
1407 }
1408
MaybeGetIntegerConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t width,bool is_signed,bool is_irrelevant)1409 uint32_t MaybeGetIntegerConstant(
1410 opt::IRContext* ir_context,
1411 const TransformationContext& transformation_context,
1412 const std::vector<uint32_t>& words, uint32_t width, bool is_signed,
1413 bool is_irrelevant) {
1414 if (auto type_id = MaybeGetIntegerType(ir_context, width, is_signed)) {
1415 return MaybeGetOpConstant(ir_context, transformation_context, words,
1416 type_id, is_irrelevant);
1417 }
1418
1419 return 0;
1420 }
1421
MaybeGetIntegerConstantFromValueAndType(opt::IRContext * ir_context,uint32_t value,uint32_t int_type_id)1422 uint32_t MaybeGetIntegerConstantFromValueAndType(opt::IRContext* ir_context,
1423 uint32_t value,
1424 uint32_t int_type_id) {
1425 auto int_type_inst = ir_context->get_def_use_mgr()->GetDef(int_type_id);
1426
1427 assert(int_type_inst && "The given type id must exist.");
1428
1429 auto int_type = ir_context->get_type_mgr()
1430 ->GetType(int_type_inst->result_id())
1431 ->AsInteger();
1432
1433 assert(int_type && int_type->width() == 32 &&
1434 "The given type id must correspond to an 32-bit integer type.");
1435
1436 opt::analysis::IntConstant constant(int_type, {value});
1437
1438 // Check that the constant exists in the module.
1439 if (!ir_context->get_constant_mgr()->FindConstant(&constant)) {
1440 return 0;
1441 }
1442
1443 return ir_context->get_constant_mgr()
1444 ->GetDefiningInstruction(&constant)
1445 ->result_id();
1446 }
1447
MaybeGetFloatConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t width,bool is_irrelevant)1448 uint32_t MaybeGetFloatConstant(
1449 opt::IRContext* ir_context,
1450 const TransformationContext& transformation_context,
1451 const std::vector<uint32_t>& words, uint32_t width, bool is_irrelevant) {
1452 if (auto type_id = MaybeGetFloatType(ir_context, width)) {
1453 return MaybeGetOpConstant(ir_context, transformation_context, words,
1454 type_id, is_irrelevant);
1455 }
1456
1457 return 0;
1458 }
1459
MaybeGetBoolConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,bool value,bool is_irrelevant)1460 uint32_t MaybeGetBoolConstant(
1461 opt::IRContext* ir_context,
1462 const TransformationContext& transformation_context, bool value,
1463 bool is_irrelevant) {
1464 if (auto type_id = MaybeGetBoolType(ir_context)) {
1465 for (const auto& inst : ir_context->types_values()) {
1466 if (inst.opcode() ==
1467 (value ? spv::Op::OpConstantTrue : spv::Op::OpConstantFalse) &&
1468 inst.type_id() == type_id &&
1469 transformation_context.GetFactManager()->IdIsIrrelevant(
1470 inst.result_id()) == is_irrelevant) {
1471 return inst.result_id();
1472 }
1473 }
1474 }
1475
1476 return 0;
1477 }
1478
// Encodes the low |width| bits of |value| as SPIR-V literal words.
//
// The bits are sign-extended if |is_signed| holds and zero-extended
// otherwise. One word is returned when |width| <= 32, and two words (low word
// first) when |width| > 32. |width| must not exceed 64; a |width| of 0 yields
// a single zero word.
std::vector<uint32_t> IntToWords(uint64_t value, uint32_t width,
                                 bool is_signed) {
  assert(width <= 64 && "The bit width should not be more than 64 bits");

  // Extend the low |width| bits to 64 bits using only well-defined unsigned
  // arithmetic. Shifting by 64 would be undefined, so a full-width value is
  // left untouched.
  if (width < 64) {
    const uint64_t low_bits_mask = (uint64_t{1} << width) - 1;
    // Zero-extend: drop every bit above the low |width| bits.
    value &= low_bits_mask;
    // Sign-extend: if the top bit of the |width|-bit value is set, fill all
    // higher bits with ones.
    if (is_signed && width > 0 && ((value >> (width - 1)) & 1u) != 0) {
      value |= ~low_bits_mask;
    }
  }

  std::vector<uint32_t> result;
  result.push_back(static_cast<uint32_t>(value));
  if (width > 32) {
    result.push_back(static_cast<uint32_t>(value >> 32));
  }
  return result;
}
1502
TypesAreEqualUpToSign(opt::IRContext * ir_context,uint32_t type1_id,uint32_t type2_id)1503 bool TypesAreEqualUpToSign(opt::IRContext* ir_context, uint32_t type1_id,
1504 uint32_t type2_id) {
1505 if (type1_id == type2_id) {
1506 return true;
1507 }
1508
1509 auto type1 = ir_context->get_type_mgr()->GetType(type1_id);
1510 auto type2 = ir_context->get_type_mgr()->GetType(type2_id);
1511
1512 // Integer scalar types must have the same width
1513 if (type1->AsInteger() && type2->AsInteger()) {
1514 return type1->AsInteger()->width() == type2->AsInteger()->width();
1515 }
1516
1517 // Integer vector types must have the same number of components and their
1518 // component types must be integers with the same width.
1519 if (type1->AsVector() && type2->AsVector()) {
1520 auto component_type1 = type1->AsVector()->element_type()->AsInteger();
1521 auto component_type2 = type2->AsVector()->element_type()->AsInteger();
1522
1523 // Only check the component count and width if they are integer.
1524 if (component_type1 && component_type2) {
1525 return type1->AsVector()->element_count() ==
1526 type2->AsVector()->element_count() &&
1527 component_type1->width() == component_type2->width();
1528 }
1529 }
1530
1531 // In all other cases, the types cannot be considered equal.
1532 return false;
1533 }
1534
RepeatedUInt32PairToMap(const google::protobuf::RepeatedPtrField<protobufs::UInt32Pair> & data)1535 std::map<uint32_t, uint32_t> RepeatedUInt32PairToMap(
1536 const google::protobuf::RepeatedPtrField<protobufs::UInt32Pair>& data) {
1537 std::map<uint32_t, uint32_t> result;
1538
1539 for (const auto& entry : data) {
1540 result[entry.first()] = entry.second();
1541 }
1542
1543 return result;
1544 }
1545
1546 google::protobuf::RepeatedPtrField<protobufs::UInt32Pair>
MapToRepeatedUInt32Pair(const std::map<uint32_t,uint32_t> & data)1547 MapToRepeatedUInt32Pair(const std::map<uint32_t, uint32_t>& data) {
1548 google::protobuf::RepeatedPtrField<protobufs::UInt32Pair> result;
1549
1550 for (const auto& entry : data) {
1551 protobufs::UInt32Pair pair;
1552 pair.set_first(entry.first);
1553 pair.set_second(entry.second);
1554 *result.Add() = std::move(pair);
1555 }
1556
1557 return result;
1558 }
1559
GetLastInsertBeforeInstruction(opt::IRContext * ir_context,uint32_t block_id,spv::Op opcode)1560 opt::Instruction* GetLastInsertBeforeInstruction(opt::IRContext* ir_context,
1561 uint32_t block_id,
1562 spv::Op opcode) {
1563 // CFG::block uses std::map::at which throws an exception when |block_id| is
1564 // invalid. The error message is unhelpful, though. Thus, we test that
1565 // |block_id| is valid here.
1566 const auto* label_inst = ir_context->get_def_use_mgr()->GetDef(block_id);
1567 (void)label_inst; // Make compilers happy in release mode.
1568 assert(label_inst && label_inst->opcode() == spv::Op::OpLabel &&
1569 "|block_id| is invalid");
1570
1571 auto* block = ir_context->cfg()->block(block_id);
1572 auto it = block->rbegin();
1573 assert(it != block->rend() && "Basic block can't be empty");
1574
1575 if (block->GetMergeInst()) {
1576 ++it;
1577 assert(it != block->rend() &&
1578 "|block| must have at least two instructions:"
1579 "terminator and a merge instruction");
1580 }
1581
1582 return CanInsertOpcodeBeforeInstruction(opcode, &*it) ? &*it : nullptr;
1583 }
1584
IdUseCanBeReplaced(opt::IRContext * ir_context,const TransformationContext & transformation_context,opt::Instruction * use_instruction,uint32_t use_in_operand_index)1585 bool IdUseCanBeReplaced(opt::IRContext* ir_context,
1586 const TransformationContext& transformation_context,
1587 opt::Instruction* use_instruction,
1588 uint32_t use_in_operand_index) {
1589 if (spvOpcodeIsAccessChain(use_instruction->opcode()) &&
1590 use_in_operand_index > 0) {
1591 // A replacement for an irrelevant index in OpAccessChain must be clamped
1592 // first.
1593 if (transformation_context.GetFactManager()->IdIsIrrelevant(
1594 use_instruction->GetSingleWordInOperand(use_in_operand_index))) {
1595 return false;
1596 }
1597
1598 // This is an access chain index. If the (sub-)object being accessed by the
1599 // given index has struct type then we cannot replace the use, as it needs
1600 // to be an OpConstant.
1601
1602 // Get the top-level composite type that is being accessed.
1603 auto object_being_accessed = ir_context->get_def_use_mgr()->GetDef(
1604 use_instruction->GetSingleWordInOperand(0));
1605 auto pointer_type =
1606 ir_context->get_type_mgr()->GetType(object_being_accessed->type_id());
1607 assert(pointer_type->AsPointer());
1608 auto composite_type_being_accessed =
1609 pointer_type->AsPointer()->pointee_type();
1610
1611 // Now walk the access chain, tracking the type of each sub-object of the
1612 // composite that is traversed, until the index of interest is reached.
1613 for (uint32_t index_in_operand = 1; index_in_operand < use_in_operand_index;
1614 index_in_operand++) {
1615 // For vectors, matrices and arrays, getting the type of the sub-object is
1616 // trivial. For the struct case, the sub-object type is field-sensitive,
1617 // and depends on the constant index that is used.
1618 if (composite_type_being_accessed->AsVector()) {
1619 composite_type_being_accessed =
1620 composite_type_being_accessed->AsVector()->element_type();
1621 } else if (composite_type_being_accessed->AsMatrix()) {
1622 composite_type_being_accessed =
1623 composite_type_being_accessed->AsMatrix()->element_type();
1624 } else if (composite_type_being_accessed->AsArray()) {
1625 composite_type_being_accessed =
1626 composite_type_being_accessed->AsArray()->element_type();
1627 } else if (composite_type_being_accessed->AsRuntimeArray()) {
1628 composite_type_being_accessed =
1629 composite_type_being_accessed->AsRuntimeArray()->element_type();
1630 } else {
1631 assert(composite_type_being_accessed->AsStruct());
1632 auto constant_index_instruction = ir_context->get_def_use_mgr()->GetDef(
1633 use_instruction->GetSingleWordInOperand(index_in_operand));
1634 assert(constant_index_instruction->opcode() == spv::Op::OpConstant);
1635 uint32_t member_index =
1636 constant_index_instruction->GetSingleWordInOperand(0);
1637 composite_type_being_accessed =
1638 composite_type_being_accessed->AsStruct()
1639 ->element_types()[member_index];
1640 }
1641 }
1642
1643 // We have found the composite type being accessed by the index we are
1644 // considering replacing. If it is a struct, then we cannot do the
1645 // replacement as struct indices must be constants.
1646 if (composite_type_being_accessed->AsStruct()) {
1647 return false;
1648 }
1649 }
1650
1651 if (use_instruction->opcode() == spv::Op::OpFunctionCall &&
1652 use_in_operand_index > 0) {
1653 // This is a function call argument. It is not allowed to have pointer
1654 // type.
1655
1656 // Get the definition of the function being called.
1657 auto function = ir_context->get_def_use_mgr()->GetDef(
1658 use_instruction->GetSingleWordInOperand(0));
1659 // From the function definition, get the function type.
1660 auto function_type = ir_context->get_def_use_mgr()->GetDef(
1661 function->GetSingleWordInOperand(1));
1662 // OpTypeFunction's 0-th input operand is the function return type, and the
1663 // function argument types follow. Because the arguments to OpFunctionCall
1664 // start from input operand 1, we can use |use_in_operand_index| to get the
1665 // type associated with this function argument.
1666 auto parameter_type = ir_context->get_type_mgr()->GetType(
1667 function_type->GetSingleWordInOperand(use_in_operand_index));
1668 if (parameter_type->AsPointer()) {
1669 return false;
1670 }
1671 }
1672
1673 if (use_instruction->opcode() == spv::Op::OpImageTexelPointer &&
1674 use_in_operand_index == 2) {
1675 // The OpImageTexelPointer instruction has a Sample parameter that in some
1676 // situations must be an id for the value 0. To guard against disrupting
1677 // that requirement, we do not replace this argument to that instruction.
1678 return false;
1679 }
1680
1681 if (ir_context->get_feature_mgr()->HasCapability(spv::Capability::Shader)) {
1682 // With the Shader capability, memory scope and memory semantics operands
1683 // are required to be constants, so they cannot be replaced arbitrarily.
1684 switch (use_instruction->opcode()) {
1685 case spv::Op::OpAtomicLoad:
1686 case spv::Op::OpAtomicStore:
1687 case spv::Op::OpAtomicExchange:
1688 case spv::Op::OpAtomicIIncrement:
1689 case spv::Op::OpAtomicIDecrement:
1690 case spv::Op::OpAtomicIAdd:
1691 case spv::Op::OpAtomicISub:
1692 case spv::Op::OpAtomicSMin:
1693 case spv::Op::OpAtomicUMin:
1694 case spv::Op::OpAtomicSMax:
1695 case spv::Op::OpAtomicUMax:
1696 case spv::Op::OpAtomicAnd:
1697 case spv::Op::OpAtomicOr:
1698 case spv::Op::OpAtomicXor:
1699 if (use_in_operand_index == 1 || use_in_operand_index == 2) {
1700 return false;
1701 }
1702 break;
1703 case spv::Op::OpAtomicCompareExchange:
1704 if (use_in_operand_index == 1 || use_in_operand_index == 2 ||
1705 use_in_operand_index == 3) {
1706 return false;
1707 }
1708 break;
1709 case spv::Op::OpAtomicCompareExchangeWeak:
1710 case spv::Op::OpAtomicFlagTestAndSet:
1711 case spv::Op::OpAtomicFlagClear:
1712 case spv::Op::OpAtomicFAddEXT:
1713 assert(false && "Not allowed with the Shader capability.");
1714 default:
1715 break;
1716 }
1717 }
1718
1719 return true;
1720 }
1721
MembersHaveBuiltInDecoration(opt::IRContext * ir_context,uint32_t struct_type_id)1722 bool MembersHaveBuiltInDecoration(opt::IRContext* ir_context,
1723 uint32_t struct_type_id) {
1724 const auto* type_inst = ir_context->get_def_use_mgr()->GetDef(struct_type_id);
1725 assert(type_inst && type_inst->opcode() == spv::Op::OpTypeStruct &&
1726 "|struct_type_id| is not a result id of an OpTypeStruct");
1727
1728 uint32_t builtin_count = 0;
1729 ir_context->get_def_use_mgr()->ForEachUser(
1730 type_inst,
1731 [struct_type_id, &builtin_count](const opt::Instruction* user) {
1732 if (user->opcode() == spv::Op::OpMemberDecorate &&
1733 user->GetSingleWordInOperand(0) == struct_type_id &&
1734 static_cast<spv::Decoration>(user->GetSingleWordInOperand(2)) ==
1735 spv::Decoration::BuiltIn) {
1736 ++builtin_count;
1737 }
1738 });
1739
1740 assert((builtin_count == 0 || builtin_count == type_inst->NumInOperands()) &&
1741 "The module is invalid: either none or all of the members of "
1742 "|struct_type_id| may be builtin");
1743
1744 return builtin_count != 0;
1745 }
1746
HasBlockOrBufferBlockDecoration(opt::IRContext * ir_context,uint32_t id)1747 bool HasBlockOrBufferBlockDecoration(opt::IRContext* ir_context, uint32_t id) {
1748 for (auto decoration :
1749 {spv::Decoration::Block, spv::Decoration::BufferBlock}) {
1750 if (!ir_context->get_decoration_mgr()->WhileEachDecoration(
1751 id, uint32_t(decoration),
1752 [](const opt::Instruction & /*unused*/) -> bool {
1753 return false;
1754 })) {
1755 return true;
1756 }
1757 }
1758 return false;
1759 }
1760
SplittingBeforeInstructionSeparatesOpSampledImageDefinitionFromUse(opt::BasicBlock * block_to_split,opt::Instruction * split_before)1761 bool SplittingBeforeInstructionSeparatesOpSampledImageDefinitionFromUse(
1762 opt::BasicBlock* block_to_split, opt::Instruction* split_before) {
1763 std::set<uint32_t> sampled_image_result_ids;
1764 bool before_split = true;
1765
1766 // Check all the instructions in the block to split.
1767 for (auto& instruction : *block_to_split) {
1768 if (&instruction == &*split_before) {
1769 before_split = false;
1770 }
1771 if (before_split) {
1772 // If the instruction comes before the split and its opcode is
1773 // OpSampledImage, record its result id.
1774 if (instruction.opcode() == spv::Op::OpSampledImage) {
1775 sampled_image_result_ids.insert(instruction.result_id());
1776 }
1777 } else {
1778 // If the instruction comes after the split, check if ids
1779 // corresponding to OpSampledImage instructions defined before the split
1780 // are used, and return true if they are.
1781 if (!instruction.WhileEachInId(
1782 [&sampled_image_result_ids](uint32_t* id) -> bool {
1783 return !sampled_image_result_ids.count(*id);
1784 })) {
1785 return true;
1786 }
1787 }
1788 }
1789
1790 // No usage that would be separated from the definition has been found.
1791 return false;
1792 }
1793
// Returns true if and only if the opcode of |instruction| is one of the
// opcodes explicitly listed below. Only the opcode is inspected; operands,
// result type and the enclosing block are not looked at.
//
// Every opcode not listed reaches the `default` case and yields false. Callers
// should therefore read false as "not known to be free of side effects" rather
// than "definitely has side effects".
bool InstructionHasNoSideEffects(const opt::Instruction& instruction) {
  switch (instruction.opcode()) {
    // Undefined values, pointer computation and composite/vector access.
    case spv::Op::OpUndef:
    case spv::Op::OpAccessChain:
    case spv::Op::OpInBoundsAccessChain:
    case spv::Op::OpArrayLength:
    case spv::Op::OpVectorExtractDynamic:
    case spv::Op::OpVectorInsertDynamic:
    case spv::Op::OpVectorShuffle:
    case spv::Op::OpCompositeConstruct:
    case spv::Op::OpCompositeExtract:
    case spv::Op::OpCompositeInsert:
    case spv::Op::OpCopyObject:
    case spv::Op::OpTranspose:
    // Conversions.
    case spv::Op::OpConvertFToU:
    case spv::Op::OpConvertFToS:
    case spv::Op::OpConvertSToF:
    case spv::Op::OpConvertUToF:
    case spv::Op::OpUConvert:
    case spv::Op::OpSConvert:
    case spv::Op::OpFConvert:
    case spv::Op::OpQuantizeToF16:
    case spv::Op::OpSatConvertSToU:
    case spv::Op::OpSatConvertUToS:
    case spv::Op::OpBitcast:
    // Arithmetic.
    case spv::Op::OpSNegate:
    case spv::Op::OpFNegate:
    case spv::Op::OpIAdd:
    case spv::Op::OpFAdd:
    case spv::Op::OpISub:
    case spv::Op::OpFSub:
    case spv::Op::OpIMul:
    case spv::Op::OpFMul:
    case spv::Op::OpUDiv:
    case spv::Op::OpSDiv:
    case spv::Op::OpFDiv:
    case spv::Op::OpUMod:
    case spv::Op::OpSRem:
    case spv::Op::OpSMod:
    case spv::Op::OpFRem:
    case spv::Op::OpFMod:
    case spv::Op::OpVectorTimesScalar:
    case spv::Op::OpMatrixTimesScalar:
    case spv::Op::OpVectorTimesMatrix:
    case spv::Op::OpMatrixTimesVector:
    case spv::Op::OpMatrixTimesMatrix:
    case spv::Op::OpOuterProduct:
    case spv::Op::OpDot:
    case spv::Op::OpIAddCarry:
    case spv::Op::OpISubBorrow:
    case spv::Op::OpUMulExtended:
    case spv::Op::OpSMulExtended:
    // Relational and logical.
    case spv::Op::OpAny:
    case spv::Op::OpAll:
    case spv::Op::OpIsNan:
    case spv::Op::OpIsInf:
    case spv::Op::OpIsFinite:
    case spv::Op::OpIsNormal:
    case spv::Op::OpSignBitSet:
    case spv::Op::OpLessOrGreater:
    case spv::Op::OpOrdered:
    case spv::Op::OpUnordered:
    case spv::Op::OpLogicalEqual:
    case spv::Op::OpLogicalNotEqual:
    case spv::Op::OpLogicalOr:
    case spv::Op::OpLogicalAnd:
    case spv::Op::OpLogicalNot:
    case spv::Op::OpSelect:
    case spv::Op::OpIEqual:
    case spv::Op::OpINotEqual:
    case spv::Op::OpUGreaterThan:
    case spv::Op::OpSGreaterThan:
    case spv::Op::OpUGreaterThanEqual:
    case spv::Op::OpSGreaterThanEqual:
    case spv::Op::OpULessThan:
    case spv::Op::OpSLessThan:
    case spv::Op::OpULessThanEqual:
    case spv::Op::OpSLessThanEqual:
    case spv::Op::OpFOrdEqual:
    case spv::Op::OpFUnordEqual:
    case spv::Op::OpFOrdNotEqual:
    case spv::Op::OpFUnordNotEqual:
    case spv::Op::OpFOrdLessThan:
    case spv::Op::OpFUnordLessThan:
    case spv::Op::OpFOrdGreaterThan:
    case spv::Op::OpFUnordGreaterThan:
    case spv::Op::OpFOrdLessThanEqual:
    case spv::Op::OpFUnordLessThanEqual:
    case spv::Op::OpFOrdGreaterThanEqual:
    case spv::Op::OpFUnordGreaterThanEqual:
    // Shifts and bit manipulation.
    case spv::Op::OpShiftRightLogical:
    case spv::Op::OpShiftRightArithmetic:
    case spv::Op::OpShiftLeftLogical:
    case spv::Op::OpBitwiseOr:
    case spv::Op::OpBitwiseXor:
    case spv::Op::OpBitwiseAnd:
    case spv::Op::OpNot:
    case spv::Op::OpBitFieldInsert:
    case spv::Op::OpBitFieldSExtract:
    case spv::Op::OpBitFieldUExtract:
    case spv::Op::OpBitReverse:
    case spv::Op::OpBitCount:
    // Miscellaneous value-producing instructions.
    case spv::Op::OpCopyLogical:
    case spv::Op::OpPhi:
    case spv::Op::OpPtrEqual:
    case spv::Op::OpPtrNotEqual:
      return true;
    default:
      // Not on the allow-list: conservatively report possible side effects.
      return false;
  }
}
1905
GetReachableReturnBlocks(opt::IRContext * ir_context,uint32_t function_id)1906 std::set<uint32_t> GetReachableReturnBlocks(opt::IRContext* ir_context,
1907 uint32_t function_id) {
1908 auto function = ir_context->GetFunction(function_id);
1909 assert(function && "The function |function_id| must exist.");
1910
1911 std::set<uint32_t> result;
1912
1913 ir_context->cfg()->ForEachBlockInPostOrder(function->entry().get(),
1914 [&result](opt::BasicBlock* block) {
1915 if (block->IsReturn()) {
1916 result.emplace(block->id());
1917 }
1918 });
1919
1920 return result;
1921 }
1922
NewTerminatorPreservesDominationRules(opt::IRContext * ir_context,uint32_t block_id,opt::Instruction new_terminator)1923 bool NewTerminatorPreservesDominationRules(opt::IRContext* ir_context,
1924 uint32_t block_id,
1925 opt::Instruction new_terminator) {
1926 auto* mutated_block = MaybeFindBlock(ir_context, block_id);
1927 assert(mutated_block && "|block_id| is invalid");
1928
1929 ChangeTerminatorRAII change_terminator_raii(mutated_block,
1930 std::move(new_terminator));
1931 opt::DominatorAnalysis dominator_analysis;
1932 dominator_analysis.InitializeTree(*ir_context->cfg(),
1933 mutated_block->GetParent());
1934
1935 // Check that each dominator appears before each dominated block.
1936 std::unordered_map<uint32_t, size_t> positions;
1937 for (const auto& block : *mutated_block->GetParent()) {
1938 positions[block.id()] = positions.size();
1939 }
1940
1941 std::queue<uint32_t> q({mutated_block->GetParent()->begin()->id()});
1942 std::unordered_set<uint32_t> visited;
1943 while (!q.empty()) {
1944 auto block = q.front();
1945 q.pop();
1946 visited.insert(block);
1947
1948 auto success = ir_context->cfg()->block(block)->WhileEachSuccessorLabel(
1949 [&positions, &visited, &dominator_analysis, block, &q](uint32_t id) {
1950 if (id == block) {
1951 // Handle the case when loop header and continue target are the same
1952 // block.
1953 return true;
1954 }
1955
1956 if (dominator_analysis.Dominates(block, id) &&
1957 positions[block] > positions[id]) {
1958 // |block| dominates |id| but appears after |id| - violates
1959 // domination rules.
1960 return false;
1961 }
1962
1963 if (!visited.count(id)) {
1964 q.push(id);
1965 }
1966
1967 return true;
1968 });
1969
1970 if (!success) {
1971 return false;
1972 }
1973 }
1974
1975 // For each instruction in the |block->GetParent()| function check whether
1976 // all its dependencies satisfy domination rules (i.e. all id operands
1977 // dominate that instruction).
1978 for (const auto& block : *mutated_block->GetParent()) {
1979 if (!ir_context->IsReachable(block)) {
1980 // If some block is not reachable then we don't need to worry about the
1981 // preservation of domination rules for its instructions.
1982 continue;
1983 }
1984
1985 for (const auto& inst : block) {
1986 for (uint32_t i = 0; i < inst.NumInOperands();
1987 i += inst.opcode() == spv::Op::OpPhi ? 2 : 1) {
1988 const auto& operand = inst.GetInOperand(i);
1989 if (!spvIsInIdType(operand.type)) {
1990 continue;
1991 }
1992
1993 if (MaybeFindBlock(ir_context, operand.words[0])) {
1994 // Ignore operands that refer to OpLabel instructions.
1995 continue;
1996 }
1997
1998 const auto* dependency_block =
1999 ir_context->get_instr_block(operand.words[0]);
2000 if (!dependency_block) {
2001 // A global instruction always dominates all instructions in any
2002 // function.
2003 continue;
2004 }
2005
2006 auto domination_target_id = inst.opcode() == spv::Op::OpPhi
2007 ? inst.GetSingleWordInOperand(i + 1)
2008 : block.id();
2009
2010 if (!dominator_analysis.Dominates(dependency_block->id(),
2011 domination_target_id)) {
2012 return false;
2013 }
2014 }
2015 }
2016 }
2017
2018 return true;
2019 }
2020
GetFunctionIterator(opt::IRContext * ir_context,uint32_t function_id)2021 opt::Module::iterator GetFunctionIterator(opt::IRContext* ir_context,
2022 uint32_t function_id) {
2023 return std::find_if(ir_context->module()->begin(),
2024 ir_context->module()->end(),
2025 [function_id](const opt::Function& f) {
2026 return f.result_id() == function_id;
2027 });
2028 }
2029
2030 // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3582): Add all
2031 // opcodes that are agnostic to signedness of operands to function.
2032 // This is not exhaustive yet.
IsAgnosticToSignednessOfOperand(spv::Op opcode,uint32_t use_in_operand_index)2033 bool IsAgnosticToSignednessOfOperand(spv::Op opcode,
2034 uint32_t use_in_operand_index) {
2035 switch (opcode) {
2036 case spv::Op::OpSNegate:
2037 case spv::Op::OpNot:
2038 case spv::Op::OpIAdd:
2039 case spv::Op::OpISub:
2040 case spv::Op::OpIMul:
2041 case spv::Op::OpSDiv:
2042 case spv::Op::OpSRem:
2043 case spv::Op::OpSMod:
2044 case spv::Op::OpShiftRightLogical:
2045 case spv::Op::OpShiftRightArithmetic:
2046 case spv::Op::OpShiftLeftLogical:
2047 case spv::Op::OpBitwiseOr:
2048 case spv::Op::OpBitwiseXor:
2049 case spv::Op::OpBitwiseAnd:
2050 case spv::Op::OpIEqual:
2051 case spv::Op::OpINotEqual:
2052 case spv::Op::OpULessThan:
2053 case spv::Op::OpSLessThan:
2054 case spv::Op::OpUGreaterThan:
2055 case spv::Op::OpSGreaterThan:
2056 case spv::Op::OpULessThanEqual:
2057 case spv::Op::OpSLessThanEqual:
2058 case spv::Op::OpUGreaterThanEqual:
2059 case spv::Op::OpSGreaterThanEqual:
2060 return true;
2061
2062 case spv::Op::OpAtomicStore:
2063 case spv::Op::OpAtomicExchange:
2064 case spv::Op::OpAtomicIAdd:
2065 case spv::Op::OpAtomicISub:
2066 case spv::Op::OpAtomicSMin:
2067 case spv::Op::OpAtomicUMin:
2068 case spv::Op::OpAtomicSMax:
2069 case spv::Op::OpAtomicUMax:
2070 case spv::Op::OpAtomicAnd:
2071 case spv::Op::OpAtomicOr:
2072 case spv::Op::OpAtomicXor:
2073 case spv::Op::OpAtomicFAddEXT: // Capability AtomicFloat32AddEXT,
2074 // AtomicFloat64AddEXT.
2075 assert(use_in_operand_index != 0 &&
2076 "Signedness check should not occur on a pointer operand.");
2077 return use_in_operand_index == 1 || use_in_operand_index == 2;
2078
2079 case spv::Op::OpAtomicCompareExchange:
2080 case spv::Op::OpAtomicCompareExchangeWeak: // Capability Kernel.
2081 assert(use_in_operand_index != 0 &&
2082 "Signedness check should not occur on a pointer operand.");
2083 return use_in_operand_index >= 1 && use_in_operand_index <= 3;
2084
2085 case spv::Op::OpAtomicLoad:
2086 case spv::Op::OpAtomicIIncrement:
2087 case spv::Op::OpAtomicIDecrement:
2088 case spv::Op::OpAtomicFlagTestAndSet: // Capability Kernel.
2089 case spv::Op::OpAtomicFlagClear: // Capability Kernel.
2090 assert(use_in_operand_index != 0 &&
2091 "Signedness check should not occur on a pointer operand.");
2092 return use_in_operand_index >= 1;
2093
2094 case spv::Op::OpAccessChain:
2095 // The signedness of indices does not matter.
2096 return use_in_operand_index > 0;
2097
2098 default:
2099 // Conservatively assume that the id cannot be swapped in other
2100 // instructions.
2101 return false;
2102 }
2103 }
2104
TypesAreCompatible(opt::IRContext * ir_context,spv::Op opcode,uint32_t use_in_operand_index,uint32_t type_id_1,uint32_t type_id_2)2105 bool TypesAreCompatible(opt::IRContext* ir_context, spv::Op opcode,
2106 uint32_t use_in_operand_index, uint32_t type_id_1,
2107 uint32_t type_id_2) {
2108 assert(ir_context->get_type_mgr()->GetType(type_id_1) &&
2109 ir_context->get_type_mgr()->GetType(type_id_2) &&
2110 "Type ids are invalid");
2111
2112 return type_id_1 == type_id_2 ||
2113 (IsAgnosticToSignednessOfOperand(opcode, use_in_operand_index) &&
2114 fuzzerutil::TypesAreEqualUpToSign(ir_context, type_id_1, type_id_2));
2115 }
2116
2117 } // namespace fuzzerutil
2118 } // namespace fuzz
2119 } // namespace spvtools
2120