1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/fuzz/fuzzer_util.h"
16
17 #include <algorithm>
18 #include <unordered_set>
19
20 #include "source/opt/build_module.h"
21
22 namespace spvtools {
23 namespace fuzz {
24
25 namespace fuzzerutil {
26 namespace {
27
28 // A utility class that uses RAII to change and restore the terminator
29 // instruction of the |block|.
30 class ChangeTerminatorRAII {
31 public:
ChangeTerminatorRAII(opt::BasicBlock * block,opt::Instruction new_terminator)32 explicit ChangeTerminatorRAII(opt::BasicBlock* block,
33 opt::Instruction new_terminator)
34 : block_(block), old_terminator_(std::move(*block->terminator())) {
35 *block_->terminator() = std::move(new_terminator);
36 }
37
~ChangeTerminatorRAII()38 ~ChangeTerminatorRAII() {
39 *block_->terminator() = std::move(old_terminator_);
40 }
41
42 private:
43 opt::BasicBlock* block_;
44 opt::Instruction old_terminator_;
45 };
46
MaybeGetOpConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t type_id,bool is_irrelevant)47 uint32_t MaybeGetOpConstant(opt::IRContext* ir_context,
48 const TransformationContext& transformation_context,
49 const std::vector<uint32_t>& words,
50 uint32_t type_id, bool is_irrelevant) {
51 for (const auto& inst : ir_context->types_values()) {
52 if (inst.opcode() == SpvOpConstant && inst.type_id() == type_id &&
53 inst.GetInOperand(0).words == words &&
54 transformation_context.GetFactManager()->IdIsIrrelevant(
55 inst.result_id()) == is_irrelevant) {
56 return inst.result_id();
57 }
58 }
59
60 return 0;
61 }
62
63 } // namespace
64
65 const spvtools::MessageConsumer kSilentMessageConsumer =
66 [](spv_message_level_t, const char*, const spv_position_t&,
__anon4d46388f0202(spv_message_level_t, const char*, const spv_position_t&, const char*) 67 const char*) -> void {};
68
BuildIRContext(spv_target_env target_env,const spvtools::MessageConsumer & message_consumer,const std::vector<uint32_t> & binary_in,spv_validator_options validator_options,std::unique_ptr<spvtools::opt::IRContext> * ir_context)69 bool BuildIRContext(spv_target_env target_env,
70 const spvtools::MessageConsumer& message_consumer,
71 const std::vector<uint32_t>& binary_in,
72 spv_validator_options validator_options,
73 std::unique_ptr<spvtools::opt::IRContext>* ir_context) {
74 SpirvTools tools(target_env);
75 tools.SetMessageConsumer(message_consumer);
76 if (!tools.IsValid()) {
77 message_consumer(SPV_MSG_ERROR, nullptr, {},
78 "Failed to create SPIRV-Tools interface; stopping.");
79 return false;
80 }
81
82 // Initial binary should be valid.
83 if (!tools.Validate(binary_in.data(), binary_in.size(), validator_options)) {
84 message_consumer(SPV_MSG_ERROR, nullptr, {},
85 "Initial binary is invalid; stopping.");
86 return false;
87 }
88
89 // Build the module from the input binary.
90 auto result = BuildModule(target_env, message_consumer, binary_in.data(),
91 binary_in.size());
92 assert(result && "IRContext must be valid");
93 *ir_context = std::move(result);
94 return true;
95 }
96
IsFreshId(opt::IRContext * context,uint32_t id)97 bool IsFreshId(opt::IRContext* context, uint32_t id) {
98 return !context->get_def_use_mgr()->GetDef(id);
99 }
100
UpdateModuleIdBound(opt::IRContext * context,uint32_t id)101 void UpdateModuleIdBound(opt::IRContext* context, uint32_t id) {
102 // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/2541) consider the
103 // case where the maximum id bound is reached.
104 context->module()->SetIdBound(
105 std::max(context->module()->id_bound(), id + 1));
106 }
107
MaybeFindBlock(opt::IRContext * context,uint32_t maybe_block_id)108 opt::BasicBlock* MaybeFindBlock(opt::IRContext* context,
109 uint32_t maybe_block_id) {
110 auto inst = context->get_def_use_mgr()->GetDef(maybe_block_id);
111 if (inst == nullptr) {
112 // No instruction defining this id was found.
113 return nullptr;
114 }
115 if (inst->opcode() != SpvOpLabel) {
116 // The instruction defining the id is not a label, so it cannot be a block
117 // id.
118 return nullptr;
119 }
120 return context->cfg()->block(maybe_block_id);
121 }
122
PhiIdsOkForNewEdge(opt::IRContext * context,opt::BasicBlock * bb_from,opt::BasicBlock * bb_to,const google::protobuf::RepeatedField<google::protobuf::uint32> & phi_ids)123 bool PhiIdsOkForNewEdge(
124 opt::IRContext* context, opt::BasicBlock* bb_from, opt::BasicBlock* bb_to,
125 const google::protobuf::RepeatedField<google::protobuf::uint32>& phi_ids) {
126 if (bb_from->IsSuccessor(bb_to)) {
127 // There is already an edge from |from_block| to |to_block|, so there is
128 // no need to extend OpPhi instructions. Do not allow phi ids to be
129 // present. This might turn out to be too strict; perhaps it would be OK
130 // just to ignore the ids in this case.
131 return phi_ids.empty();
132 }
133 // The edge would add a previously non-existent edge from |from_block| to
134 // |to_block|, so we go through the given phi ids and check that they exactly
135 // match the OpPhi instructions in |to_block|.
136 uint32_t phi_index = 0;
137 // An explicit loop, rather than applying a lambda to each OpPhi in |bb_to|,
138 // makes sense here because we need to increment |phi_index| for each OpPhi
139 // instruction.
140 for (auto& inst : *bb_to) {
141 if (inst.opcode() != SpvOpPhi) {
142 // The OpPhi instructions all occur at the start of the block; if we find
143 // a non-OpPhi then we have seen them all.
144 break;
145 }
146 if (phi_index == static_cast<uint32_t>(phi_ids.size())) {
147 // Not enough phi ids have been provided to account for the OpPhi
148 // instructions.
149 return false;
150 }
151 // Look for an instruction defining the next phi id.
152 opt::Instruction* phi_extension =
153 context->get_def_use_mgr()->GetDef(phi_ids[phi_index]);
154 if (!phi_extension) {
155 // The id given to extend this OpPhi does not exist.
156 return false;
157 }
158 if (phi_extension->type_id() != inst.type_id()) {
159 // The instruction given to extend this OpPhi either does not have a type
160 // or its type does not match that of the OpPhi.
161 return false;
162 }
163
164 if (context->get_instr_block(phi_extension)) {
165 // The instruction defining the phi id has an associated block (i.e., it
166 // is not a global value). Check whether its definition dominates the
167 // exit of |from_block|.
168 auto dominator_analysis =
169 context->GetDominatorAnalysis(bb_from->GetParent());
170 if (!dominator_analysis->Dominates(phi_extension,
171 bb_from->terminator())) {
172 // The given id is no good as its definition does not dominate the exit
173 // of |from_block|
174 return false;
175 }
176 }
177 phi_index++;
178 }
179 // We allow some of the ids provided for extending OpPhi instructions to be
180 // unused. Their presence does no harm, and requiring a perfect match may
181 // make transformations less likely to cleanly apply.
182 return true;
183 }
184
CreateUnreachableEdgeInstruction(opt::IRContext * ir_context,uint32_t bb_from_id,uint32_t bb_to_id,uint32_t bool_id)185 opt::Instruction CreateUnreachableEdgeInstruction(opt::IRContext* ir_context,
186 uint32_t bb_from_id,
187 uint32_t bb_to_id,
188 uint32_t bool_id) {
189 const auto* bb_from = MaybeFindBlock(ir_context, bb_from_id);
190 assert(bb_from && "|bb_from_id| is invalid");
191 assert(MaybeFindBlock(ir_context, bb_to_id) && "|bb_to_id| is invalid");
192 assert(bb_from->terminator()->opcode() == SpvOpBranch &&
193 "Precondition on terminator of bb_from is not satisfied");
194
195 // Get the id of the boolean constant to be used as the condition.
196 auto condition_inst = ir_context->get_def_use_mgr()->GetDef(bool_id);
197 assert(condition_inst &&
198 (condition_inst->opcode() == SpvOpConstantTrue ||
199 condition_inst->opcode() == SpvOpConstantFalse) &&
200 "|bool_id| is invalid");
201
202 auto condition_value = condition_inst->opcode() == SpvOpConstantTrue;
203 auto successor_id = bb_from->terminator()->GetSingleWordInOperand(0);
204
205 // Add the dead branch, by turning OpBranch into OpBranchConditional, and
206 // ordering the targets depending on whether the given boolean corresponds to
207 // true or false.
208 return opt::Instruction(
209 ir_context, SpvOpBranchConditional, 0, 0,
210 {{SPV_OPERAND_TYPE_ID, {bool_id}},
211 {SPV_OPERAND_TYPE_ID, {condition_value ? successor_id : bb_to_id}},
212 {SPV_OPERAND_TYPE_ID, {condition_value ? bb_to_id : successor_id}}});
213 }
214
AddUnreachableEdgeAndUpdateOpPhis(opt::IRContext * context,opt::BasicBlock * bb_from,opt::BasicBlock * bb_to,uint32_t bool_id,const google::protobuf::RepeatedField<google::protobuf::uint32> & phi_ids)215 void AddUnreachableEdgeAndUpdateOpPhis(
216 opt::IRContext* context, opt::BasicBlock* bb_from, opt::BasicBlock* bb_to,
217 uint32_t bool_id,
218 const google::protobuf::RepeatedField<google::protobuf::uint32>& phi_ids) {
219 assert(PhiIdsOkForNewEdge(context, bb_from, bb_to, phi_ids) &&
220 "Precondition on phi_ids is not satisfied");
221
222 const bool from_to_edge_already_exists = bb_from->IsSuccessor(bb_to);
223 *bb_from->terminator() = CreateUnreachableEdgeInstruction(
224 context, bb_from->id(), bb_to->id(), bool_id);
225
226 // Update OpPhi instructions in the target block if this branch adds a
227 // previously non-existent edge from source to target.
228 if (!from_to_edge_already_exists) {
229 uint32_t phi_index = 0;
230 for (auto& inst : *bb_to) {
231 if (inst.opcode() != SpvOpPhi) {
232 break;
233 }
234 assert(phi_index < static_cast<uint32_t>(phi_ids.size()) &&
235 "There should be at least one phi id per OpPhi instruction.");
236 inst.AddOperand({SPV_OPERAND_TYPE_ID, {phi_ids[phi_index]}});
237 inst.AddOperand({SPV_OPERAND_TYPE_ID, {bb_from->id()}});
238 phi_index++;
239 }
240 }
241 }
242
BlockIsBackEdge(opt::IRContext * context,uint32_t block_id,uint32_t loop_header_id)243 bool BlockIsBackEdge(opt::IRContext* context, uint32_t block_id,
244 uint32_t loop_header_id) {
245 auto block = context->cfg()->block(block_id);
246 auto loop_header = context->cfg()->block(loop_header_id);
247
248 // |block| and |loop_header| must be defined, |loop_header| must be in fact
249 // loop header and |block| must branch to it.
250 if (!(block && loop_header && loop_header->IsLoopHeader() &&
251 block->IsSuccessor(loop_header))) {
252 return false;
253 }
254
255 // |block| must be reachable and be dominated by |loop_header|.
256 opt::DominatorAnalysis* dominator_analysis =
257 context->GetDominatorAnalysis(loop_header->GetParent());
258 return context->IsReachable(*block) &&
259 dominator_analysis->Dominates(loop_header, block);
260 }
261
BlockIsInLoopContinueConstruct(opt::IRContext * context,uint32_t block_id,uint32_t maybe_loop_header_id)262 bool BlockIsInLoopContinueConstruct(opt::IRContext* context, uint32_t block_id,
263 uint32_t maybe_loop_header_id) {
264 // We deem a block to be part of a loop's continue construct if the loop's
265 // continue target dominates the block.
266 auto containing_construct_block = context->cfg()->block(maybe_loop_header_id);
267 if (containing_construct_block->IsLoopHeader()) {
268 auto continue_target = containing_construct_block->ContinueBlockId();
269 if (context->GetDominatorAnalysis(containing_construct_block->GetParent())
270 ->Dominates(continue_target, block_id)) {
271 return true;
272 }
273 }
274 return false;
275 }
276
GetIteratorForInstruction(opt::BasicBlock * block,const opt::Instruction * inst)277 opt::BasicBlock::iterator GetIteratorForInstruction(
278 opt::BasicBlock* block, const opt::Instruction* inst) {
279 for (auto inst_it = block->begin(); inst_it != block->end(); ++inst_it) {
280 if (inst == &*inst_it) {
281 return inst_it;
282 }
283 }
284 return block->end();
285 }
286
CanInsertOpcodeBeforeInstruction(SpvOp opcode,const opt::BasicBlock::iterator & instruction_in_block)287 bool CanInsertOpcodeBeforeInstruction(
288 SpvOp opcode, const opt::BasicBlock::iterator& instruction_in_block) {
289 if (instruction_in_block->PreviousNode() &&
290 (instruction_in_block->PreviousNode()->opcode() == SpvOpLoopMerge ||
291 instruction_in_block->PreviousNode()->opcode() == SpvOpSelectionMerge)) {
292 // We cannot insert directly after a merge instruction.
293 return false;
294 }
295 if (opcode != SpvOpVariable &&
296 instruction_in_block->opcode() == SpvOpVariable) {
297 // We cannot insert a non-OpVariable instruction directly before a
298 // variable; variables in a function must be contiguous in the entry block.
299 return false;
300 }
301 // We cannot insert a non-OpPhi instruction directly before an OpPhi, because
302 // OpPhi instructions need to be contiguous at the start of a block.
303 return opcode == SpvOpPhi || instruction_in_block->opcode() != SpvOpPhi;
304 }
305
CanMakeSynonymOf(opt::IRContext * ir_context,const TransformationContext & transformation_context,const opt::Instruction & inst)306 bool CanMakeSynonymOf(opt::IRContext* ir_context,
307 const TransformationContext& transformation_context,
308 const opt::Instruction& inst) {
309 if (inst.opcode() == SpvOpSampledImage) {
310 // The SPIR-V data rules say that only very specific instructions may
311 // may consume the result id of an OpSampledImage, and this excludes the
312 // instructions that are used for making synonyms.
313 return false;
314 }
315 if (!inst.HasResultId()) {
316 // We can only make a synonym of an instruction that generates an id.
317 return false;
318 }
319 if (transformation_context.GetFactManager()->IdIsIrrelevant(
320 inst.result_id())) {
321 // An irrelevant id can't be a synonym of anything.
322 return false;
323 }
324 if (!inst.type_id()) {
325 // We can only make a synonym of an instruction that has a type.
326 return false;
327 }
328 auto type_inst = ir_context->get_def_use_mgr()->GetDef(inst.type_id());
329 if (type_inst->opcode() == SpvOpTypeVoid) {
330 // We only make synonyms of instructions that define objects, and an object
331 // cannot have void type.
332 return false;
333 }
334 if (type_inst->opcode() == SpvOpTypePointer) {
335 switch (inst.opcode()) {
336 case SpvOpConstantNull:
337 case SpvOpUndef:
338 // We disallow making synonyms of null or undefined pointers. This is
339 // to provide the property that if the original shader exhibited no bad
340 // pointer accesses, the transformed shader will not either.
341 return false;
342 default:
343 break;
344 }
345 }
346
347 // We do not make synonyms of objects that have decorations: if the synonym is
348 // not decorated analogously, using the original object vs. its synonymous
349 // form may not be equivalent.
350 return ir_context->get_decoration_mgr()
351 ->GetDecorationsFor(inst.result_id(), true)
352 .empty();
353 }
354
IsCompositeType(const opt::analysis::Type * type)355 bool IsCompositeType(const opt::analysis::Type* type) {
356 return type && (type->AsArray() || type->AsMatrix() || type->AsStruct() ||
357 type->AsVector());
358 }
359
RepeatedFieldToVector(const google::protobuf::RepeatedField<uint32_t> & repeated_field)360 std::vector<uint32_t> RepeatedFieldToVector(
361 const google::protobuf::RepeatedField<uint32_t>& repeated_field) {
362 std::vector<uint32_t> result;
363 for (auto i : repeated_field) {
364 result.push_back(i);
365 }
366 return result;
367 }
368
WalkOneCompositeTypeIndex(opt::IRContext * context,uint32_t base_object_type_id,uint32_t index)369 uint32_t WalkOneCompositeTypeIndex(opt::IRContext* context,
370 uint32_t base_object_type_id,
371 uint32_t index) {
372 auto should_be_composite_type =
373 context->get_def_use_mgr()->GetDef(base_object_type_id);
374 assert(should_be_composite_type && "The type should exist.");
375 switch (should_be_composite_type->opcode()) {
376 case SpvOpTypeArray: {
377 auto array_length = GetArraySize(*should_be_composite_type, context);
378 if (array_length == 0 || index >= array_length) {
379 return 0;
380 }
381 return should_be_composite_type->GetSingleWordInOperand(0);
382 }
383 case SpvOpTypeMatrix:
384 case SpvOpTypeVector: {
385 auto count = should_be_composite_type->GetSingleWordInOperand(1);
386 if (index >= count) {
387 return 0;
388 }
389 return should_be_composite_type->GetSingleWordInOperand(0);
390 }
391 case SpvOpTypeStruct: {
392 if (index >= GetNumberOfStructMembers(*should_be_composite_type)) {
393 return 0;
394 }
395 return should_be_composite_type->GetSingleWordInOperand(index);
396 }
397 default:
398 return 0;
399 }
400 }
401
WalkCompositeTypeIndices(opt::IRContext * context,uint32_t base_object_type_id,const google::protobuf::RepeatedField<google::protobuf::uint32> & indices)402 uint32_t WalkCompositeTypeIndices(
403 opt::IRContext* context, uint32_t base_object_type_id,
404 const google::protobuf::RepeatedField<google::protobuf::uint32>& indices) {
405 uint32_t sub_object_type_id = base_object_type_id;
406 for (auto index : indices) {
407 sub_object_type_id =
408 WalkOneCompositeTypeIndex(context, sub_object_type_id, index);
409 if (!sub_object_type_id) {
410 return 0;
411 }
412 }
413 return sub_object_type_id;
414 }
415
GetNumberOfStructMembers(const opt::Instruction & struct_type_instruction)416 uint32_t GetNumberOfStructMembers(
417 const opt::Instruction& struct_type_instruction) {
418 assert(struct_type_instruction.opcode() == SpvOpTypeStruct &&
419 "An OpTypeStruct instruction is required here.");
420 return struct_type_instruction.NumInOperands();
421 }
422
GetArraySize(const opt::Instruction & array_type_instruction,opt::IRContext * context)423 uint32_t GetArraySize(const opt::Instruction& array_type_instruction,
424 opt::IRContext* context) {
425 auto array_length_constant =
426 context->get_constant_mgr()
427 ->GetConstantFromInst(context->get_def_use_mgr()->GetDef(
428 array_type_instruction.GetSingleWordInOperand(1)))
429 ->AsIntConstant();
430 if (array_length_constant->words().size() != 1) {
431 return 0;
432 }
433 return array_length_constant->GetU32();
434 }
435
GetBoundForCompositeIndex(const opt::Instruction & composite_type_inst,opt::IRContext * ir_context)436 uint32_t GetBoundForCompositeIndex(const opt::Instruction& composite_type_inst,
437 opt::IRContext* ir_context) {
438 switch (composite_type_inst.opcode()) {
439 case SpvOpTypeArray:
440 return fuzzerutil::GetArraySize(composite_type_inst, ir_context);
441 case SpvOpTypeMatrix:
442 case SpvOpTypeVector:
443 return composite_type_inst.GetSingleWordInOperand(1);
444 case SpvOpTypeStruct: {
445 return fuzzerutil::GetNumberOfStructMembers(composite_type_inst);
446 }
447 case SpvOpTypeRuntimeArray:
448 assert(false &&
449 "GetBoundForCompositeIndex should not be invoked with an "
450 "OpTypeRuntimeArray, which does not have a static bound.");
451 return 0;
452 default:
453 assert(false && "Unknown composite type.");
454 return 0;
455 }
456 }
457
GetMemorySemanticsForStorageClass(SpvStorageClass storage_class)458 SpvMemorySemanticsMask GetMemorySemanticsForStorageClass(
459 SpvStorageClass storage_class) {
460 switch (storage_class) {
461 case SpvStorageClassWorkgroup:
462 return SpvMemorySemanticsWorkgroupMemoryMask;
463
464 case SpvStorageClassStorageBuffer:
465 case SpvStorageClassPhysicalStorageBuffer:
466 return SpvMemorySemanticsUniformMemoryMask;
467
468 case SpvStorageClassCrossWorkgroup:
469 return SpvMemorySemanticsCrossWorkgroupMemoryMask;
470
471 case SpvStorageClassAtomicCounter:
472 return SpvMemorySemanticsAtomicCounterMemoryMask;
473
474 case SpvStorageClassImage:
475 return SpvMemorySemanticsImageMemoryMask;
476
477 default:
478 return SpvMemorySemanticsMaskNone;
479 }
480 }
481
IsValid(const opt::IRContext * context,spv_validator_options validator_options,MessageConsumer consumer)482 bool IsValid(const opt::IRContext* context,
483 spv_validator_options validator_options,
484 MessageConsumer consumer) {
485 std::vector<uint32_t> binary;
486 context->module()->ToBinary(&binary, false);
487 SpirvTools tools(context->grammar().target_env());
488 tools.SetMessageConsumer(std::move(consumer));
489 return tools.Validate(binary.data(), binary.size(), validator_options);
490 }
491
IsValidAndWellFormed(const opt::IRContext * ir_context,spv_validator_options validator_options,MessageConsumer consumer)492 bool IsValidAndWellFormed(const opt::IRContext* ir_context,
493 spv_validator_options validator_options,
494 MessageConsumer consumer) {
495 if (!IsValid(ir_context, validator_options, consumer)) {
496 // Expression to dump |ir_context| to /data/temp/shader.spv:
497 // DumpShader(ir_context, "/data/temp/shader.spv")
498 consumer(SPV_MSG_INFO, nullptr, {},
499 "Module is invalid (set a breakpoint to inspect).");
500 return false;
501 }
502 // Check that all blocks in the module have appropriate parent functions.
503 for (auto& function : *ir_context->module()) {
504 for (auto& block : function) {
505 if (block.GetParent() == nullptr) {
506 std::stringstream ss;
507 ss << "Block " << block.id() << " has no parent; its parent should be "
508 << function.result_id() << " (set a breakpoint to inspect).";
509 consumer(SPV_MSG_INFO, nullptr, {}, ss.str().c_str());
510 return false;
511 }
512 if (block.GetParent() != &function) {
513 std::stringstream ss;
514 ss << "Block " << block.id() << " should have parent "
515 << function.result_id() << " but instead has parent "
516 << block.GetParent() << " (set a breakpoint to inspect).";
517 consumer(SPV_MSG_INFO, nullptr, {}, ss.str().c_str());
518 return false;
519 }
520 }
521 }
522
523 // Check that all instructions have distinct unique ids. We map each unique
524 // id to the first instruction it is observed to be associated with so that
525 // if we encounter a duplicate we have access to the previous instruction -
526 // this is a useful aid to debugging.
527 std::unordered_map<uint32_t, opt::Instruction*> unique_ids;
528 bool found_duplicate = false;
529 ir_context->module()->ForEachInst([&consumer, &found_duplicate, ir_context,
530 &unique_ids](opt::Instruction* inst) {
531 (void)ir_context; // Only used in an assertion; keep release-mode compilers
532 // happy.
533 assert(inst->context() == ir_context &&
534 "Instruction has wrong IR context.");
535 if (unique_ids.count(inst->unique_id()) != 0) {
536 consumer(SPV_MSG_INFO, nullptr, {},
537 "Two instructions have the same unique id (set a breakpoint to "
538 "inspect).");
539 found_duplicate = true;
540 }
541 unique_ids.insert({inst->unique_id(), inst});
542 });
543 return !found_duplicate;
544 }
545
CloneIRContext(opt::IRContext * context)546 std::unique_ptr<opt::IRContext> CloneIRContext(opt::IRContext* context) {
547 std::vector<uint32_t> binary;
548 context->module()->ToBinary(&binary, false);
549 return BuildModule(context->grammar().target_env(), nullptr, binary.data(),
550 binary.size());
551 }
552
IsNonFunctionTypeId(opt::IRContext * ir_context,uint32_t id)553 bool IsNonFunctionTypeId(opt::IRContext* ir_context, uint32_t id) {
554 auto type = ir_context->get_type_mgr()->GetType(id);
555 return type && !type->AsFunction();
556 }
557
IsMergeOrContinue(opt::IRContext * ir_context,uint32_t block_id)558 bool IsMergeOrContinue(opt::IRContext* ir_context, uint32_t block_id) {
559 bool result = false;
560 ir_context->get_def_use_mgr()->WhileEachUse(
561 block_id,
562 [&result](const opt::Instruction* use_instruction,
563 uint32_t /*unused*/) -> bool {
564 switch (use_instruction->opcode()) {
565 case SpvOpLoopMerge:
566 case SpvOpSelectionMerge:
567 result = true;
568 return false;
569 default:
570 return true;
571 }
572 });
573 return result;
574 }
575
GetLoopFromMergeBlock(opt::IRContext * ir_context,uint32_t merge_block_id)576 uint32_t GetLoopFromMergeBlock(opt::IRContext* ir_context,
577 uint32_t merge_block_id) {
578 uint32_t result = 0;
579 ir_context->get_def_use_mgr()->WhileEachUse(
580 merge_block_id,
581 [ir_context, &result](opt::Instruction* use_instruction,
582 uint32_t use_index) -> bool {
583 switch (use_instruction->opcode()) {
584 case SpvOpLoopMerge:
585 // The merge block operand is the first operand in OpLoopMerge.
586 if (use_index == 0) {
587 result = ir_context->get_instr_block(use_instruction)->id();
588 return false;
589 }
590 return true;
591 default:
592 return true;
593 }
594 });
595 return result;
596 }
597
FindFunctionType(opt::IRContext * ir_context,const std::vector<uint32_t> & type_ids)598 uint32_t FindFunctionType(opt::IRContext* ir_context,
599 const std::vector<uint32_t>& type_ids) {
600 // Look through the existing types for a match.
601 for (auto& type_or_value : ir_context->types_values()) {
602 if (type_or_value.opcode() != SpvOpTypeFunction) {
603 // We are only interested in function types.
604 continue;
605 }
606 if (type_or_value.NumInOperands() != type_ids.size()) {
607 // Not a match: different numbers of arguments.
608 continue;
609 }
610 // Check whether the return type and argument types match.
611 bool input_operands_match = true;
612 for (uint32_t i = 0; i < type_or_value.NumInOperands(); i++) {
613 if (type_ids[i] != type_or_value.GetSingleWordInOperand(i)) {
614 input_operands_match = false;
615 break;
616 }
617 }
618 if (input_operands_match) {
619 // Everything matches.
620 return type_or_value.result_id();
621 }
622 }
623 // No match was found.
624 return 0;
625 }
626
GetFunctionType(opt::IRContext * context,const opt::Function * function)627 opt::Instruction* GetFunctionType(opt::IRContext* context,
628 const opt::Function* function) {
629 uint32_t type_id = function->DefInst().GetSingleWordInOperand(1);
630 return context->get_def_use_mgr()->GetDef(type_id);
631 }
632
FindFunction(opt::IRContext * ir_context,uint32_t function_id)633 opt::Function* FindFunction(opt::IRContext* ir_context, uint32_t function_id) {
634 for (auto& function : *ir_context->module()) {
635 if (function.result_id() == function_id) {
636 return &function;
637 }
638 }
639 return nullptr;
640 }
641
FunctionContainsOpKillOrUnreachable(const opt::Function & function)642 bool FunctionContainsOpKillOrUnreachable(const opt::Function& function) {
643 for (auto& block : function) {
644 if (block.terminator()->opcode() == SpvOpKill ||
645 block.terminator()->opcode() == SpvOpUnreachable) {
646 return true;
647 }
648 }
649 return false;
650 }
651
FunctionIsEntryPoint(opt::IRContext * context,uint32_t function_id)652 bool FunctionIsEntryPoint(opt::IRContext* context, uint32_t function_id) {
653 for (auto& entry_point : context->module()->entry_points()) {
654 if (entry_point.GetSingleWordInOperand(1) == function_id) {
655 return true;
656 }
657 }
658 return false;
659 }
660
IdIsAvailableAtUse(opt::IRContext * context,opt::Instruction * use_instruction,uint32_t use_input_operand_index,uint32_t id)661 bool IdIsAvailableAtUse(opt::IRContext* context,
662 opt::Instruction* use_instruction,
663 uint32_t use_input_operand_index, uint32_t id) {
664 assert(context->get_instr_block(use_instruction) &&
665 "|use_instruction| must be in a basic block");
666
667 auto defining_instruction = context->get_def_use_mgr()->GetDef(id);
668 auto enclosing_function =
669 context->get_instr_block(use_instruction)->GetParent();
670 // If the id a function parameter, it needs to be associated with the
671 // function containing the use.
672 if (defining_instruction->opcode() == SpvOpFunctionParameter) {
673 return InstructionIsFunctionParameter(defining_instruction,
674 enclosing_function);
675 }
676 if (!context->get_instr_block(id)) {
677 // The id must be at global scope.
678 return true;
679 }
680 if (defining_instruction == use_instruction) {
681 // It is not OK for a definition to use itself.
682 return false;
683 }
684 if (!context->IsReachable(*context->get_instr_block(use_instruction)) ||
685 !context->IsReachable(*context->get_instr_block(id))) {
686 // Skip unreachable blocks.
687 return false;
688 }
689 auto dominator_analysis = context->GetDominatorAnalysis(enclosing_function);
690 if (use_instruction->opcode() == SpvOpPhi) {
691 // In the case where the use is an operand to OpPhi, it is actually the
692 // *parent* block associated with the operand that must be dominated by
693 // the synonym.
694 auto parent_block =
695 use_instruction->GetSingleWordInOperand(use_input_operand_index + 1);
696 return dominator_analysis->Dominates(
697 context->get_instr_block(defining_instruction)->id(), parent_block);
698 }
699 return dominator_analysis->Dominates(defining_instruction, use_instruction);
700 }
701
IdIsAvailableBeforeInstruction(opt::IRContext * context,opt::Instruction * instruction,uint32_t id)702 bool IdIsAvailableBeforeInstruction(opt::IRContext* context,
703 opt::Instruction* instruction,
704 uint32_t id) {
705 assert(context->get_instr_block(instruction) &&
706 "|instruction| must be in a basic block");
707
708 auto id_definition = context->get_def_use_mgr()->GetDef(id);
709 auto function_enclosing_instruction =
710 context->get_instr_block(instruction)->GetParent();
711 // If the id a function parameter, it needs to be associated with the
712 // function containing the instruction.
713 if (id_definition->opcode() == SpvOpFunctionParameter) {
714 return InstructionIsFunctionParameter(id_definition,
715 function_enclosing_instruction);
716 }
717 if (!context->get_instr_block(id)) {
718 // The id is at global scope.
719 return true;
720 }
721 if (id_definition == instruction) {
722 // The instruction is not available right before its own definition.
723 return false;
724 }
725 const auto* dominator_analysis =
726 context->GetDominatorAnalysis(function_enclosing_instruction);
727 if (context->IsReachable(*context->get_instr_block(instruction)) &&
728 context->IsReachable(*context->get_instr_block(id)) &&
729 dominator_analysis->Dominates(id_definition, instruction)) {
730 // The id's definition dominates the instruction, and both the definition
731 // and the instruction are in reachable blocks, thus the id is available at
732 // the instruction.
733 return true;
734 }
735 if (id_definition->opcode() == SpvOpVariable &&
736 function_enclosing_instruction ==
737 context->get_instr_block(id)->GetParent()) {
738 assert(!context->IsReachable(*context->get_instr_block(instruction)) &&
739 "If the instruction were in a reachable block we should already "
740 "have returned true.");
741 // The id is a variable and it is in the same function as |instruction|.
742 // This is OK despite |instruction| being unreachable.
743 return true;
744 }
745 return false;
746 }
747
InstructionIsFunctionParameter(opt::Instruction * instruction,opt::Function * function)748 bool InstructionIsFunctionParameter(opt::Instruction* instruction,
749 opt::Function* function) {
750 if (instruction->opcode() != SpvOpFunctionParameter) {
751 return false;
752 }
753 bool found_parameter = false;
754 function->ForEachParam(
755 [instruction, &found_parameter](opt::Instruction* param) {
756 if (param == instruction) {
757 found_parameter = true;
758 }
759 });
760 return found_parameter;
761 }
762
GetTypeId(opt::IRContext * context,uint32_t result_id)763 uint32_t GetTypeId(opt::IRContext* context, uint32_t result_id) {
764 const auto* inst = context->get_def_use_mgr()->GetDef(result_id);
765 assert(inst && "|result_id| is invalid");
766 return inst->type_id();
767 }
768
GetPointeeTypeIdFromPointerType(opt::Instruction * pointer_type_inst)769 uint32_t GetPointeeTypeIdFromPointerType(opt::Instruction* pointer_type_inst) {
770 assert(pointer_type_inst && pointer_type_inst->opcode() == SpvOpTypePointer &&
771 "Precondition: |pointer_type_inst| must be OpTypePointer.");
772 return pointer_type_inst->GetSingleWordInOperand(1);
773 }
774
GetPointeeTypeIdFromPointerType(opt::IRContext * context,uint32_t pointer_type_id)775 uint32_t GetPointeeTypeIdFromPointerType(opt::IRContext* context,
776 uint32_t pointer_type_id) {
777 return GetPointeeTypeIdFromPointerType(
778 context->get_def_use_mgr()->GetDef(pointer_type_id));
779 }
780
GetStorageClassFromPointerType(opt::Instruction * pointer_type_inst)781 SpvStorageClass GetStorageClassFromPointerType(
782 opt::Instruction* pointer_type_inst) {
783 assert(pointer_type_inst && pointer_type_inst->opcode() == SpvOpTypePointer &&
784 "Precondition: |pointer_type_inst| must be OpTypePointer.");
785 return static_cast<SpvStorageClass>(
786 pointer_type_inst->GetSingleWordInOperand(0));
787 }
788
GetStorageClassFromPointerType(opt::IRContext * context,uint32_t pointer_type_id)789 SpvStorageClass GetStorageClassFromPointerType(opt::IRContext* context,
790 uint32_t pointer_type_id) {
791 return GetStorageClassFromPointerType(
792 context->get_def_use_mgr()->GetDef(pointer_type_id));
793 }
794
MaybeGetPointerType(opt::IRContext * context,uint32_t pointee_type_id,SpvStorageClass storage_class)795 uint32_t MaybeGetPointerType(opt::IRContext* context, uint32_t pointee_type_id,
796 SpvStorageClass storage_class) {
797 for (auto& inst : context->types_values()) {
798 switch (inst.opcode()) {
799 case SpvOpTypePointer:
800 if (inst.GetSingleWordInOperand(0) == storage_class &&
801 inst.GetSingleWordInOperand(1) == pointee_type_id) {
802 return inst.result_id();
803 }
804 break;
805 default:
806 break;
807 }
808 }
809 return 0;
810 }
811
InOperandIndexFromOperandIndex(const opt::Instruction & inst,uint32_t absolute_index)812 uint32_t InOperandIndexFromOperandIndex(const opt::Instruction& inst,
813 uint32_t absolute_index) {
814 // Subtract the number of non-input operands from the index
815 return absolute_index - inst.NumOperands() + inst.NumInOperands();
816 }
817
IsNullConstantSupported(opt::IRContext * ir_context,const opt::Instruction & type_inst)818 bool IsNullConstantSupported(opt::IRContext* ir_context,
819 const opt::Instruction& type_inst) {
820 switch (type_inst.opcode()) {
821 case SpvOpTypeArray:
822 case SpvOpTypeBool:
823 case SpvOpTypeDeviceEvent:
824 case SpvOpTypeEvent:
825 case SpvOpTypeFloat:
826 case SpvOpTypeInt:
827 case SpvOpTypeMatrix:
828 case SpvOpTypeQueue:
829 case SpvOpTypeReserveId:
830 case SpvOpTypeVector:
831 case SpvOpTypeStruct:
832 return true;
833 case SpvOpTypePointer:
834 // Null pointers are allowed if the VariablePointers capability is
835 // enabled, or if the VariablePointersStorageBuffer capability is enabled
836 // and the pointer type has StorageBuffer as its storage class.
837 if (ir_context->get_feature_mgr()->HasCapability(
838 SpvCapabilityVariablePointers)) {
839 return true;
840 }
841 if (ir_context->get_feature_mgr()->HasCapability(
842 SpvCapabilityVariablePointersStorageBuffer)) {
843 return type_inst.GetSingleWordInOperand(0) ==
844 SpvStorageClassStorageBuffer;
845 }
846 return false;
847 default:
848 return false;
849 }
850 }
851
GlobalVariablesMustBeDeclaredInEntryPointInterfaces(const opt::IRContext * ir_context)852 bool GlobalVariablesMustBeDeclaredInEntryPointInterfaces(
853 const opt::IRContext* ir_context) {
854 // TODO(afd): We capture the environments for which this requirement holds.
855 // The check should be refined on demand for other target environments.
856 switch (ir_context->grammar().target_env()) {
857 case SPV_ENV_UNIVERSAL_1_0:
858 case SPV_ENV_UNIVERSAL_1_1:
859 case SPV_ENV_UNIVERSAL_1_2:
860 case SPV_ENV_UNIVERSAL_1_3:
861 case SPV_ENV_VULKAN_1_0:
862 case SPV_ENV_VULKAN_1_1:
863 return false;
864 default:
865 return true;
866 }
867 }
868
AddVariableIdToEntryPointInterfaces(opt::IRContext * context,uint32_t id)869 void AddVariableIdToEntryPointInterfaces(opt::IRContext* context, uint32_t id) {
870 if (GlobalVariablesMustBeDeclaredInEntryPointInterfaces(context)) {
871 // Conservatively add this global to the interface of every entry point in
872 // the module. This means that the global is available for other
873 // transformations to use.
874 //
875 // A downside of this is that the global will be in the interface even if it
876 // ends up never being used.
877 //
878 // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3111) revisit
879 // this if a more thorough approach to entry point interfaces is taken.
880 for (auto& entry_point : context->module()->entry_points()) {
881 entry_point.AddOperand({SPV_OPERAND_TYPE_ID, {id}});
882 }
883 }
884 }
885
AddGlobalVariable(opt::IRContext * context,uint32_t result_id,uint32_t type_id,SpvStorageClass storage_class,uint32_t initializer_id)886 opt::Instruction* AddGlobalVariable(opt::IRContext* context, uint32_t result_id,
887 uint32_t type_id,
888 SpvStorageClass storage_class,
889 uint32_t initializer_id) {
890 // Check various preconditions.
891 assert(result_id != 0 && "Result id can't be 0");
892
893 assert((storage_class == SpvStorageClassPrivate ||
894 storage_class == SpvStorageClassWorkgroup) &&
895 "Variable's storage class must be either Private or Workgroup");
896
897 auto* type_inst = context->get_def_use_mgr()->GetDef(type_id);
898 (void)type_inst; // Variable becomes unused in release mode.
899 assert(type_inst && type_inst->opcode() == SpvOpTypePointer &&
900 GetStorageClassFromPointerType(type_inst) == storage_class &&
901 "Variable's type is invalid");
902
903 if (storage_class == SpvStorageClassWorkgroup) {
904 assert(initializer_id == 0);
905 }
906
907 if (initializer_id != 0) {
908 const auto* constant_inst =
909 context->get_def_use_mgr()->GetDef(initializer_id);
910 (void)constant_inst; // Variable becomes unused in release mode.
911 assert(constant_inst && spvOpcodeIsConstant(constant_inst->opcode()) &&
912 GetPointeeTypeIdFromPointerType(type_inst) ==
913 constant_inst->type_id() &&
914 "Initializer is invalid");
915 }
916
917 opt::Instruction::OperandList operands = {
918 {SPV_OPERAND_TYPE_STORAGE_CLASS, {static_cast<uint32_t>(storage_class)}}};
919
920 if (initializer_id) {
921 operands.push_back({SPV_OPERAND_TYPE_ID, {initializer_id}});
922 }
923
924 auto new_instruction = MakeUnique<opt::Instruction>(
925 context, SpvOpVariable, type_id, result_id, std::move(operands));
926 auto result = new_instruction.get();
927 context->module()->AddGlobalValue(std::move(new_instruction));
928
929 AddVariableIdToEntryPointInterfaces(context, result_id);
930 UpdateModuleIdBound(context, result_id);
931
932 return result;
933 }
934
AddLocalVariable(opt::IRContext * context,uint32_t result_id,uint32_t type_id,uint32_t function_id,uint32_t initializer_id)935 opt::Instruction* AddLocalVariable(opt::IRContext* context, uint32_t result_id,
936 uint32_t type_id, uint32_t function_id,
937 uint32_t initializer_id) {
938 // Check various preconditions.
939 assert(result_id != 0 && "Result id can't be 0");
940
941 auto* type_inst = context->get_def_use_mgr()->GetDef(type_id);
942 (void)type_inst; // Variable becomes unused in release mode.
943 assert(type_inst && type_inst->opcode() == SpvOpTypePointer &&
944 GetStorageClassFromPointerType(type_inst) == SpvStorageClassFunction &&
945 "Variable's type is invalid");
946
947 const auto* constant_inst =
948 context->get_def_use_mgr()->GetDef(initializer_id);
949 (void)constant_inst; // Variable becomes unused in release mode.
950 assert(constant_inst && spvOpcodeIsConstant(constant_inst->opcode()) &&
951 GetPointeeTypeIdFromPointerType(type_inst) ==
952 constant_inst->type_id() &&
953 "Initializer is invalid");
954
955 auto* function = FindFunction(context, function_id);
956 assert(function && "Function id is invalid");
957
958 auto new_instruction = MakeUnique<opt::Instruction>(
959 context, SpvOpVariable, type_id, result_id,
960 opt::Instruction::OperandList{
961 {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
962 {SPV_OPERAND_TYPE_ID, {initializer_id}}});
963 auto result = new_instruction.get();
964 function->begin()->begin()->InsertBefore(std::move(new_instruction));
965
966 UpdateModuleIdBound(context, result_id);
967
968 return result;
969 }
970
// Returns true if some value occurs more than once in |arr|.
bool HasDuplicates(const std::vector<uint32_t>& arr) {
  std::unordered_set<uint32_t> seen_values;
  for (uint32_t value : arr) {
    // insert() reports false when the value was already present.
    if (!seen_values.insert(value).second) {
      return true;
    }
  }
  return false;
}
975
IsPermutationOfRange(const std::vector<uint32_t> & arr,uint32_t lo,uint32_t hi)976 bool IsPermutationOfRange(const std::vector<uint32_t>& arr, uint32_t lo,
977 uint32_t hi) {
978 if (arr.empty()) {
979 return lo > hi;
980 }
981
982 if (HasDuplicates(arr)) {
983 return false;
984 }
985
986 auto min_max = std::minmax_element(arr.begin(), arr.end());
987 return arr.size() == hi - lo + 1 && *min_max.first == lo &&
988 *min_max.second == hi;
989 }
990
GetParameters(opt::IRContext * ir_context,uint32_t function_id)991 std::vector<opt::Instruction*> GetParameters(opt::IRContext* ir_context,
992 uint32_t function_id) {
993 auto* function = FindFunction(ir_context, function_id);
994 assert(function && "|function_id| is invalid");
995
996 std::vector<opt::Instruction*> result;
997 function->ForEachParam(
998 [&result](opt::Instruction* inst) { result.push_back(inst); });
999
1000 return result;
1001 }
1002
RemoveParameter(opt::IRContext * ir_context,uint32_t parameter_id)1003 void RemoveParameter(opt::IRContext* ir_context, uint32_t parameter_id) {
1004 auto* function = GetFunctionFromParameterId(ir_context, parameter_id);
1005 assert(function && "|parameter_id| is invalid");
1006 assert(!FunctionIsEntryPoint(ir_context, function->result_id()) &&
1007 "Can't remove parameter from an entry point function");
1008
1009 function->RemoveParameter(parameter_id);
1010
1011 // We've just removed parameters from the function and cleared their memory.
1012 // Make sure analyses have no dangling pointers.
1013 ir_context->InvalidateAnalysesExceptFor(
1014 opt::IRContext::Analysis::kAnalysisNone);
1015 }
1016
GetCallers(opt::IRContext * ir_context,uint32_t function_id)1017 std::vector<opt::Instruction*> GetCallers(opt::IRContext* ir_context,
1018 uint32_t function_id) {
1019 assert(FindFunction(ir_context, function_id) &&
1020 "|function_id| is not a result id of a function");
1021
1022 std::vector<opt::Instruction*> result;
1023 ir_context->get_def_use_mgr()->ForEachUser(
1024 function_id, [&result, function_id](opt::Instruction* inst) {
1025 if (inst->opcode() == SpvOpFunctionCall &&
1026 inst->GetSingleWordInOperand(0) == function_id) {
1027 result.push_back(inst);
1028 }
1029 });
1030
1031 return result;
1032 }
1033
GetFunctionFromParameterId(opt::IRContext * ir_context,uint32_t param_id)1034 opt::Function* GetFunctionFromParameterId(opt::IRContext* ir_context,
1035 uint32_t param_id) {
1036 auto* param_inst = ir_context->get_def_use_mgr()->GetDef(param_id);
1037 assert(param_inst && "Parameter id is invalid");
1038
1039 for (auto& function : *ir_context->module()) {
1040 if (InstructionIsFunctionParameter(param_inst, &function)) {
1041 return &function;
1042 }
1043 }
1044
1045 return nullptr;
1046 }
1047
UpdateFunctionType(opt::IRContext * ir_context,uint32_t function_id,uint32_t new_function_type_result_id,uint32_t return_type_id,const std::vector<uint32_t> & parameter_type_ids)1048 uint32_t UpdateFunctionType(opt::IRContext* ir_context, uint32_t function_id,
1049 uint32_t new_function_type_result_id,
1050 uint32_t return_type_id,
1051 const std::vector<uint32_t>& parameter_type_ids) {
1052 // Check some initial constraints.
1053 assert(ir_context->get_type_mgr()->GetType(return_type_id) &&
1054 "Return type is invalid");
1055 for (auto id : parameter_type_ids) {
1056 const auto* type = ir_context->get_type_mgr()->GetType(id);
1057 (void)type; // Make compilers happy in release mode.
1058 // Parameters can't be OpTypeVoid.
1059 assert(type && !type->AsVoid() && "Parameter has invalid type");
1060 }
1061
1062 auto* function = FindFunction(ir_context, function_id);
1063 assert(function && "|function_id| is invalid");
1064
1065 auto* old_function_type = GetFunctionType(ir_context, function);
1066 assert(old_function_type && "Function has invalid type");
1067
1068 std::vector<uint32_t> operand_ids = {return_type_id};
1069 operand_ids.insert(operand_ids.end(), parameter_type_ids.begin(),
1070 parameter_type_ids.end());
1071
1072 // A trivial case - we change nothing.
1073 if (FindFunctionType(ir_context, operand_ids) ==
1074 old_function_type->result_id()) {
1075 return old_function_type->result_id();
1076 }
1077
1078 if (ir_context->get_def_use_mgr()->NumUsers(old_function_type) == 1 &&
1079 FindFunctionType(ir_context, operand_ids) == 0) {
1080 // We can change |old_function_type| only if it's used once in the module
1081 // and we are certain we won't create a duplicate as a result of the change.
1082
1083 // Update |old_function_type| in-place.
1084 opt::Instruction::OperandList operands;
1085 for (auto id : operand_ids) {
1086 operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
1087 }
1088
1089 old_function_type->SetInOperands(std::move(operands));
1090
1091 // |operands| may depend on result ids defined below the |old_function_type|
1092 // in the module.
1093 old_function_type->RemoveFromList();
1094 ir_context->AddType(std::unique_ptr<opt::Instruction>(old_function_type));
1095 return old_function_type->result_id();
1096 } else {
1097 // We can't modify the |old_function_type| so we have to either use an
1098 // existing one or create a new one.
1099 auto type_id = FindOrCreateFunctionType(
1100 ir_context, new_function_type_result_id, operand_ids);
1101 assert(type_id != old_function_type->result_id() &&
1102 "We should've handled this case above");
1103
1104 function->DefInst().SetInOperand(1, {type_id});
1105
1106 // DefUseManager hasn't been updated yet, so if the following condition is
1107 // true, then |old_function_type| will have no users when this function
1108 // returns. We might as well remove it.
1109 if (ir_context->get_def_use_mgr()->NumUsers(old_function_type) == 1) {
1110 ir_context->KillInst(old_function_type);
1111 }
1112
1113 return type_id;
1114 }
1115 }
1116
AddFunctionType(opt::IRContext * ir_context,uint32_t result_id,const std::vector<uint32_t> & type_ids)1117 void AddFunctionType(opt::IRContext* ir_context, uint32_t result_id,
1118 const std::vector<uint32_t>& type_ids) {
1119 assert(result_id != 0 && "Result id can't be 0");
1120 assert(!type_ids.empty() &&
1121 "OpTypeFunction always has at least one operand - function's return "
1122 "type");
1123 assert(IsNonFunctionTypeId(ir_context, type_ids[0]) &&
1124 "Return type must not be a function");
1125
1126 for (size_t i = 1; i < type_ids.size(); ++i) {
1127 const auto* param_type = ir_context->get_type_mgr()->GetType(type_ids[i]);
1128 (void)param_type; // Make compiler happy in release mode.
1129 assert(param_type && !param_type->AsVoid() && !param_type->AsFunction() &&
1130 "Function parameter can't have a function or void type");
1131 }
1132
1133 opt::Instruction::OperandList operands;
1134 operands.reserve(type_ids.size());
1135 for (auto id : type_ids) {
1136 operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
1137 }
1138
1139 ir_context->AddType(MakeUnique<opt::Instruction>(
1140 ir_context, SpvOpTypeFunction, 0, result_id, std::move(operands)));
1141
1142 UpdateModuleIdBound(ir_context, result_id);
1143 }
1144
FindOrCreateFunctionType(opt::IRContext * ir_context,uint32_t result_id,const std::vector<uint32_t> & type_ids)1145 uint32_t FindOrCreateFunctionType(opt::IRContext* ir_context,
1146 uint32_t result_id,
1147 const std::vector<uint32_t>& type_ids) {
1148 if (auto existing_id = FindFunctionType(ir_context, type_ids)) {
1149 return existing_id;
1150 }
1151 AddFunctionType(ir_context, result_id, type_ids);
1152 return result_id;
1153 }
1154
MaybeGetIntegerType(opt::IRContext * ir_context,uint32_t width,bool is_signed)1155 uint32_t MaybeGetIntegerType(opt::IRContext* ir_context, uint32_t width,
1156 bool is_signed) {
1157 opt::analysis::Integer type(width, is_signed);
1158 return ir_context->get_type_mgr()->GetId(&type);
1159 }
1160
MaybeGetFloatType(opt::IRContext * ir_context,uint32_t width)1161 uint32_t MaybeGetFloatType(opt::IRContext* ir_context, uint32_t width) {
1162 opt::analysis::Float type(width);
1163 return ir_context->get_type_mgr()->GetId(&type);
1164 }
1165
MaybeGetBoolType(opt::IRContext * ir_context)1166 uint32_t MaybeGetBoolType(opt::IRContext* ir_context) {
1167 opt::analysis::Bool type;
1168 return ir_context->get_type_mgr()->GetId(&type);
1169 }
1170
MaybeGetVectorType(opt::IRContext * ir_context,uint32_t component_type_id,uint32_t element_count)1171 uint32_t MaybeGetVectorType(opt::IRContext* ir_context,
1172 uint32_t component_type_id,
1173 uint32_t element_count) {
1174 const auto* component_type =
1175 ir_context->get_type_mgr()->GetType(component_type_id);
1176 assert(component_type &&
1177 (component_type->AsInteger() || component_type->AsFloat() ||
1178 component_type->AsBool()) &&
1179 "|component_type_id| is invalid");
1180 assert(element_count >= 2 && element_count <= 4 &&
1181 "Precondition: component count must be in range [2, 4].");
1182 opt::analysis::Vector type(component_type, element_count);
1183 return ir_context->get_type_mgr()->GetId(&type);
1184 }
1185
MaybeGetStructType(opt::IRContext * ir_context,const std::vector<uint32_t> & component_type_ids)1186 uint32_t MaybeGetStructType(opt::IRContext* ir_context,
1187 const std::vector<uint32_t>& component_type_ids) {
1188 for (auto& type_or_value : ir_context->types_values()) {
1189 if (type_or_value.opcode() != SpvOpTypeStruct ||
1190 type_or_value.NumInOperands() !=
1191 static_cast<uint32_t>(component_type_ids.size())) {
1192 continue;
1193 }
1194 bool all_components_match = true;
1195 for (uint32_t i = 0; i < component_type_ids.size(); i++) {
1196 if (type_or_value.GetSingleWordInOperand(i) != component_type_ids[i]) {
1197 all_components_match = false;
1198 break;
1199 }
1200 }
1201 if (all_components_match) {
1202 return type_or_value.result_id();
1203 }
1204 }
1205 return 0;
1206 }
1207
MaybeGetVoidType(opt::IRContext * ir_context)1208 uint32_t MaybeGetVoidType(opt::IRContext* ir_context) {
1209 opt::analysis::Void type;
1210 return ir_context->get_type_mgr()->GetId(&type);
1211 }
1212
MaybeGetZeroConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,uint32_t scalar_or_composite_type_id,bool is_irrelevant)1213 uint32_t MaybeGetZeroConstant(
1214 opt::IRContext* ir_context,
1215 const TransformationContext& transformation_context,
1216 uint32_t scalar_or_composite_type_id, bool is_irrelevant) {
1217 const auto* type_inst =
1218 ir_context->get_def_use_mgr()->GetDef(scalar_or_composite_type_id);
1219 assert(type_inst && "|scalar_or_composite_type_id| is invalid");
1220
1221 switch (type_inst->opcode()) {
1222 case SpvOpTypeBool:
1223 return MaybeGetBoolConstant(ir_context, transformation_context, false,
1224 is_irrelevant);
1225 case SpvOpTypeFloat:
1226 case SpvOpTypeInt: {
1227 const auto width = type_inst->GetSingleWordInOperand(0);
1228 std::vector<uint32_t> words = {0};
1229 if (width > 32) {
1230 words.push_back(0);
1231 }
1232
1233 return MaybeGetScalarConstant(ir_context, transformation_context, words,
1234 scalar_or_composite_type_id, is_irrelevant);
1235 }
1236 case SpvOpTypeStruct: {
1237 std::vector<uint32_t> component_ids;
1238 for (uint32_t i = 0; i < type_inst->NumInOperands(); ++i) {
1239 const auto component_type_id = type_inst->GetSingleWordInOperand(i);
1240
1241 auto component_id =
1242 MaybeGetZeroConstant(ir_context, transformation_context,
1243 component_type_id, is_irrelevant);
1244
1245 if (component_id == 0 && is_irrelevant) {
1246 // Irrelevant constants can use either relevant or irrelevant
1247 // constituents.
1248 component_id = MaybeGetZeroConstant(
1249 ir_context, transformation_context, component_type_id, false);
1250 }
1251
1252 if (component_id == 0) {
1253 return 0;
1254 }
1255
1256 component_ids.push_back(component_id);
1257 }
1258
1259 return MaybeGetCompositeConstant(
1260 ir_context, transformation_context, component_ids,
1261 scalar_or_composite_type_id, is_irrelevant);
1262 }
1263 case SpvOpTypeMatrix:
1264 case SpvOpTypeVector: {
1265 const auto component_type_id = type_inst->GetSingleWordInOperand(0);
1266
1267 auto component_id = MaybeGetZeroConstant(
1268 ir_context, transformation_context, component_type_id, is_irrelevant);
1269
1270 if (component_id == 0 && is_irrelevant) {
1271 // Irrelevant constants can use either relevant or irrelevant
1272 // constituents.
1273 component_id = MaybeGetZeroConstant(ir_context, transformation_context,
1274 component_type_id, false);
1275 }
1276
1277 if (component_id == 0) {
1278 return 0;
1279 }
1280
1281 const auto component_count = type_inst->GetSingleWordInOperand(1);
1282 return MaybeGetCompositeConstant(
1283 ir_context, transformation_context,
1284 std::vector<uint32_t>(component_count, component_id),
1285 scalar_or_composite_type_id, is_irrelevant);
1286 }
1287 case SpvOpTypeArray: {
1288 const auto component_type_id = type_inst->GetSingleWordInOperand(0);
1289
1290 auto component_id = MaybeGetZeroConstant(
1291 ir_context, transformation_context, component_type_id, is_irrelevant);
1292
1293 if (component_id == 0 && is_irrelevant) {
1294 // Irrelevant constants can use either relevant or irrelevant
1295 // constituents.
1296 component_id = MaybeGetZeroConstant(ir_context, transformation_context,
1297 component_type_id, false);
1298 }
1299
1300 if (component_id == 0) {
1301 return 0;
1302 }
1303
1304 return MaybeGetCompositeConstant(
1305 ir_context, transformation_context,
1306 std::vector<uint32_t>(GetArraySize(*type_inst, ir_context),
1307 component_id),
1308 scalar_or_composite_type_id, is_irrelevant);
1309 }
1310 default:
1311 assert(false && "Type is not supported");
1312 return 0;
1313 }
1314 }
1315
CanCreateConstant(opt::IRContext * ir_context,uint32_t type_id)1316 bool CanCreateConstant(opt::IRContext* ir_context, uint32_t type_id) {
1317 opt::Instruction* type_instr = ir_context->get_def_use_mgr()->GetDef(type_id);
1318 assert(type_instr != nullptr && "The type must exist.");
1319 assert(spvOpcodeGeneratesType(type_instr->opcode()) &&
1320 "A type-generating opcode was expected.");
1321 switch (type_instr->opcode()) {
1322 case SpvOpTypeBool:
1323 case SpvOpTypeInt:
1324 case SpvOpTypeFloat:
1325 case SpvOpTypeMatrix:
1326 case SpvOpTypeVector:
1327 return true;
1328 case SpvOpTypeArray:
1329 return CanCreateConstant(ir_context,
1330 type_instr->GetSingleWordInOperand(0));
1331 case SpvOpTypeStruct:
1332 if (HasBlockOrBufferBlockDecoration(ir_context, type_id)) {
1333 return false;
1334 }
1335 for (uint32_t index = 0; index < type_instr->NumInOperands(); index++) {
1336 if (!CanCreateConstant(ir_context,
1337 type_instr->GetSingleWordInOperand(index))) {
1338 return false;
1339 }
1340 }
1341 return true;
1342 default:
1343 return false;
1344 }
1345 }
1346
MaybeGetScalarConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t scalar_type_id,bool is_irrelevant)1347 uint32_t MaybeGetScalarConstant(
1348 opt::IRContext* ir_context,
1349 const TransformationContext& transformation_context,
1350 const std::vector<uint32_t>& words, uint32_t scalar_type_id,
1351 bool is_irrelevant) {
1352 const auto* type = ir_context->get_type_mgr()->GetType(scalar_type_id);
1353 assert(type && "|scalar_type_id| is invalid");
1354
1355 if (const auto* int_type = type->AsInteger()) {
1356 return MaybeGetIntegerConstant(ir_context, transformation_context, words,
1357 int_type->width(), int_type->IsSigned(),
1358 is_irrelevant);
1359 } else if (const auto* float_type = type->AsFloat()) {
1360 return MaybeGetFloatConstant(ir_context, transformation_context, words,
1361 float_type->width(), is_irrelevant);
1362 } else {
1363 assert(type->AsBool() && words.size() == 1 &&
1364 "|scalar_type_id| doesn't represent a scalar type");
1365 return MaybeGetBoolConstant(ir_context, transformation_context, words[0],
1366 is_irrelevant);
1367 }
1368 }
1369
MaybeGetCompositeConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & component_ids,uint32_t composite_type_id,bool is_irrelevant)1370 uint32_t MaybeGetCompositeConstant(
1371 opt::IRContext* ir_context,
1372 const TransformationContext& transformation_context,
1373 const std::vector<uint32_t>& component_ids, uint32_t composite_type_id,
1374 bool is_irrelevant) {
1375 const auto* type = ir_context->get_type_mgr()->GetType(composite_type_id);
1376 (void)type; // Make compilers happy in release mode.
1377 assert(IsCompositeType(type) && "|composite_type_id| is invalid");
1378
1379 for (const auto& inst : ir_context->types_values()) {
1380 if (inst.opcode() == SpvOpConstantComposite &&
1381 inst.type_id() == composite_type_id &&
1382 transformation_context.GetFactManager()->IdIsIrrelevant(
1383 inst.result_id()) == is_irrelevant &&
1384 inst.NumInOperands() == component_ids.size()) {
1385 bool is_match = true;
1386
1387 for (uint32_t i = 0; i < inst.NumInOperands(); ++i) {
1388 if (inst.GetSingleWordInOperand(i) != component_ids[i]) {
1389 is_match = false;
1390 break;
1391 }
1392 }
1393
1394 if (is_match) {
1395 return inst.result_id();
1396 }
1397 }
1398 }
1399
1400 return 0;
1401 }
1402
MaybeGetIntegerConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t width,bool is_signed,bool is_irrelevant)1403 uint32_t MaybeGetIntegerConstant(
1404 opt::IRContext* ir_context,
1405 const TransformationContext& transformation_context,
1406 const std::vector<uint32_t>& words, uint32_t width, bool is_signed,
1407 bool is_irrelevant) {
1408 if (auto type_id = MaybeGetIntegerType(ir_context, width, is_signed)) {
1409 return MaybeGetOpConstant(ir_context, transformation_context, words,
1410 type_id, is_irrelevant);
1411 }
1412
1413 return 0;
1414 }
1415
MaybeGetIntegerConstantFromValueAndType(opt::IRContext * ir_context,uint32_t value,uint32_t int_type_id)1416 uint32_t MaybeGetIntegerConstantFromValueAndType(opt::IRContext* ir_context,
1417 uint32_t value,
1418 uint32_t int_type_id) {
1419 auto int_type_inst = ir_context->get_def_use_mgr()->GetDef(int_type_id);
1420
1421 assert(int_type_inst && "The given type id must exist.");
1422
1423 auto int_type = ir_context->get_type_mgr()
1424 ->GetType(int_type_inst->result_id())
1425 ->AsInteger();
1426
1427 assert(int_type && int_type->width() == 32 &&
1428 "The given type id must correspond to an 32-bit integer type.");
1429
1430 opt::analysis::IntConstant constant(int_type, {value});
1431
1432 // Check that the constant exists in the module.
1433 if (!ir_context->get_constant_mgr()->FindConstant(&constant)) {
1434 return 0;
1435 }
1436
1437 return ir_context->get_constant_mgr()
1438 ->GetDefiningInstruction(&constant)
1439 ->result_id();
1440 }
1441
MaybeGetFloatConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t width,bool is_irrelevant)1442 uint32_t MaybeGetFloatConstant(
1443 opt::IRContext* ir_context,
1444 const TransformationContext& transformation_context,
1445 const std::vector<uint32_t>& words, uint32_t width, bool is_irrelevant) {
1446 if (auto type_id = MaybeGetFloatType(ir_context, width)) {
1447 return MaybeGetOpConstant(ir_context, transformation_context, words,
1448 type_id, is_irrelevant);
1449 }
1450
1451 return 0;
1452 }
1453
MaybeGetBoolConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,bool value,bool is_irrelevant)1454 uint32_t MaybeGetBoolConstant(
1455 opt::IRContext* ir_context,
1456 const TransformationContext& transformation_context, bool value,
1457 bool is_irrelevant) {
1458 if (auto type_id = MaybeGetBoolType(ir_context)) {
1459 for (const auto& inst : ir_context->types_values()) {
1460 if (inst.opcode() == (value ? SpvOpConstantTrue : SpvOpConstantFalse) &&
1461 inst.type_id() == type_id &&
1462 transformation_context.GetFactManager()->IdIsIrrelevant(
1463 inst.result_id()) == is_irrelevant) {
1464 return inst.result_id();
1465 }
1466 }
1467 }
1468
1469 return 0;
1470 }
1471
// Returns the SPIR-V literal words for an integer of |width| bits (1 to 64)
// holding the low |width| bits of |value|.
//
// Those bits are first extended to 64 bits: sign-extended when |is_signed|,
// zero-extended otherwise. The result has one word when |width| <= 32, and two
// words (low word first) otherwise.
std::vector<uint32_t> IntToWords(uint64_t value, uint32_t width,
                                 bool is_signed) {
  assert(width > 0 && width <= 64 &&
         "The bit width should be between 1 and 64 bits");

  // Extend with masks rather than a left/right shift pair. A shift pair would
  // shift by 64 (undefined behavior) when |width| is 0. It would also depend on
  // implementation-defined conversion and right shift of negative signed
  // values.
  if (width < 64) {
    const uint64_t low_bits_mask = (uint64_t{1} << width) - 1;
    value &= low_bits_mask;
    // After masking, only bit |width| - 1 can survive this shift.
    const bool sign_bit_set = width > 0 && (value >> (width - 1)) != 0;
    if (is_signed && sign_bit_set) {
      value |= ~low_bits_mask;
    }
  }

  std::vector<uint32_t> result;
  result.push_back(static_cast<uint32_t>(value));
  if (width > 32) {
    result.push_back(static_cast<uint32_t>(value >> 32));
  }
  return result;
}
1495
TypesAreEqualUpToSign(opt::IRContext * ir_context,uint32_t type1_id,uint32_t type2_id)1496 bool TypesAreEqualUpToSign(opt::IRContext* ir_context, uint32_t type1_id,
1497 uint32_t type2_id) {
1498 if (type1_id == type2_id) {
1499 return true;
1500 }
1501
1502 auto type1 = ir_context->get_type_mgr()->GetType(type1_id);
1503 auto type2 = ir_context->get_type_mgr()->GetType(type2_id);
1504
1505 // Integer scalar types must have the same width
1506 if (type1->AsInteger() && type2->AsInteger()) {
1507 return type1->AsInteger()->width() == type2->AsInteger()->width();
1508 }
1509
1510 // Integer vector types must have the same number of components and their
1511 // component types must be integers with the same width.
1512 if (type1->AsVector() && type2->AsVector()) {
1513 auto component_type1 = type1->AsVector()->element_type()->AsInteger();
1514 auto component_type2 = type2->AsVector()->element_type()->AsInteger();
1515
1516 // Only check the component count and width if they are integer.
1517 if (component_type1 && component_type2) {
1518 return type1->AsVector()->element_count() ==
1519 type2->AsVector()->element_count() &&
1520 component_type1->width() == component_type2->width();
1521 }
1522 }
1523
1524 // In all other cases, the types cannot be considered equal.
1525 return false;
1526 }
1527
RepeatedUInt32PairToMap(const google::protobuf::RepeatedPtrField<protobufs::UInt32Pair> & data)1528 std::map<uint32_t, uint32_t> RepeatedUInt32PairToMap(
1529 const google::protobuf::RepeatedPtrField<protobufs::UInt32Pair>& data) {
1530 std::map<uint32_t, uint32_t> result;
1531
1532 for (const auto& entry : data) {
1533 result[entry.first()] = entry.second();
1534 }
1535
1536 return result;
1537 }
1538
1539 google::protobuf::RepeatedPtrField<protobufs::UInt32Pair>
MapToRepeatedUInt32Pair(const std::map<uint32_t,uint32_t> & data)1540 MapToRepeatedUInt32Pair(const std::map<uint32_t, uint32_t>& data) {
1541 google::protobuf::RepeatedPtrField<protobufs::UInt32Pair> result;
1542
1543 for (const auto& entry : data) {
1544 protobufs::UInt32Pair pair;
1545 pair.set_first(entry.first);
1546 pair.set_second(entry.second);
1547 *result.Add() = std::move(pair);
1548 }
1549
1550 return result;
1551 }
1552
GetLastInsertBeforeInstruction(opt::IRContext * ir_context,uint32_t block_id,SpvOp opcode)1553 opt::Instruction* GetLastInsertBeforeInstruction(opt::IRContext* ir_context,
1554 uint32_t block_id,
1555 SpvOp opcode) {
1556 // CFG::block uses std::map::at which throws an exception when |block_id| is
1557 // invalid. The error message is unhelpful, though. Thus, we test that
1558 // |block_id| is valid here.
1559 const auto* label_inst = ir_context->get_def_use_mgr()->GetDef(block_id);
1560 (void)label_inst; // Make compilers happy in release mode.
1561 assert(label_inst && label_inst->opcode() == SpvOpLabel &&
1562 "|block_id| is invalid");
1563
1564 auto* block = ir_context->cfg()->block(block_id);
1565 auto it = block->rbegin();
1566 assert(it != block->rend() && "Basic block can't be empty");
1567
1568 if (block->GetMergeInst()) {
1569 ++it;
1570 assert(it != block->rend() &&
1571 "|block| must have at least two instructions:"
1572 "terminator and a merge instruction");
1573 }
1574
1575 return CanInsertOpcodeBeforeInstruction(opcode, &*it) ? &*it : nullptr;
1576 }
1577
// Returns true if the id used at in-operand |use_in_operand_index| of
// |use_instruction| may be replaced by a different id. Returns false in the
// following cases:
//  - access-chain indices that are irrelevant ids (they must be clamped
//    first), or that index into a struct (struct indices must be constants);
//  - OpFunctionCall arguments whose parameter type is a pointer;
//  - the Sample operand (in-operand 2) of OpImageTexelPointer;
//  - with the Shader capability, the memory scope / memory semantics operands
//    of the atomic instructions listed below.
bool IdUseCanBeReplaced(opt::IRContext* ir_context,
                        const TransformationContext& transformation_context,
                        opt::Instruction* use_instruction,
                        uint32_t use_in_operand_index) {
  // NOTE(review): if spvOpcodeIsAccessChain also matches the pointer access
  // chains (OpPtrAccessChain / OpInBoundsPtrAccessChain), their in-operand 1
  // is the Element operand rather than an index into the pointee type, so the
  // type walk below would be misaligned for them; confirm.
  if (spvOpcodeIsAccessChain(use_instruction->opcode()) &&
      use_in_operand_index > 0) {
    // A replacement for an irrelevant index in OpAccessChain must be clamped
    // first.
    if (transformation_context.GetFactManager()->IdIsIrrelevant(
            use_instruction->GetSingleWordInOperand(use_in_operand_index))) {
      return false;
    }

    // This is an access chain index. If the (sub-)object being accessed by the
    // given index has struct type then we cannot replace the use, as it needs
    // to be an OpConstant.

    // Get the top-level composite type that is being accessed.
    auto object_being_accessed = ir_context->get_def_use_mgr()->GetDef(
        use_instruction->GetSingleWordInOperand(0));
    auto pointer_type =
        ir_context->get_type_mgr()->GetType(object_being_accessed->type_id());
    assert(pointer_type->AsPointer());
    auto composite_type_being_accessed =
        pointer_type->AsPointer()->pointee_type();

    // Now walk the access chain, tracking the type of each sub-object of the
    // composite that is traversed, until the index of interest is reached.
    for (uint32_t index_in_operand = 1; index_in_operand < use_in_operand_index;
         index_in_operand++) {
      // For vectors, matrices and arrays, getting the type of the sub-object is
      // trivial. For the struct case, the sub-object type is field-sensitive,
      // and depends on the constant index that is used.
      if (composite_type_being_accessed->AsVector()) {
        composite_type_being_accessed =
            composite_type_being_accessed->AsVector()->element_type();
      } else if (composite_type_being_accessed->AsMatrix()) {
        composite_type_being_accessed =
            composite_type_being_accessed->AsMatrix()->element_type();
      } else if (composite_type_being_accessed->AsArray()) {
        composite_type_being_accessed =
            composite_type_being_accessed->AsArray()->element_type();
      } else if (composite_type_being_accessed->AsRuntimeArray()) {
        composite_type_being_accessed =
            composite_type_being_accessed->AsRuntimeArray()->element_type();
      } else {
        assert(composite_type_being_accessed->AsStruct());
        // Struct indices are OpConstant; its first in-operand holds the
        // (32-bit) member index.
        auto constant_index_instruction = ir_context->get_def_use_mgr()->GetDef(
            use_instruction->GetSingleWordInOperand(index_in_operand));
        assert(constant_index_instruction->opcode() == SpvOpConstant);
        uint32_t member_index =
            constant_index_instruction->GetSingleWordInOperand(0);
        composite_type_being_accessed =
            composite_type_being_accessed->AsStruct()
                ->element_types()[member_index];
      }
    }

    // We have found the composite type being accessed by the index we are
    // considering replacing. If it is a struct, then we cannot do the
    // replacement as struct indices must be constants.
    if (composite_type_being_accessed->AsStruct()) {
      return false;
    }
  }

  if (use_instruction->opcode() == SpvOpFunctionCall &&
      use_in_operand_index > 0) {
    // This is a function call argument. It is not allowed to have pointer
    // type.

    // Get the definition of the function being called.
    auto function = ir_context->get_def_use_mgr()->GetDef(
        use_instruction->GetSingleWordInOperand(0));
    // From the function definition, get the function type.
    auto function_type = ir_context->get_def_use_mgr()->GetDef(
        function->GetSingleWordInOperand(1));
    // OpTypeFunction's 0-th input operand is the function return type, and the
    // function argument types follow. Because the arguments to OpFunctionCall
    // start from input operand 1, we can use |use_in_operand_index| to get the
    // type associated with this function argument.
    auto parameter_type = ir_context->get_type_mgr()->GetType(
        function_type->GetSingleWordInOperand(use_in_operand_index));
    if (parameter_type->AsPointer()) {
      return false;
    }
  }

  if (use_instruction->opcode() == SpvOpImageTexelPointer &&
      use_in_operand_index == 2) {
    // The OpImageTexelPointer instruction has a Sample parameter that in some
    // situations must be an id for the value 0. To guard against disrupting
    // that requirement, we do not replace this argument to that instruction.
    return false;
  }

  if (ir_context->get_feature_mgr()->HasCapability(SpvCapabilityShader)) {
    // With the Shader capability, memory scope and memory semantics operands
    // are required to be constants, so they cannot be replaced arbitrarily.
    switch (use_instruction->opcode()) {
      case SpvOpAtomicLoad:
      case SpvOpAtomicStore:
      case SpvOpAtomicExchange:
      case SpvOpAtomicIIncrement:
      case SpvOpAtomicIDecrement:
      case SpvOpAtomicIAdd:
      case SpvOpAtomicISub:
      case SpvOpAtomicSMin:
      case SpvOpAtomicUMin:
      case SpvOpAtomicSMax:
      case SpvOpAtomicUMax:
      case SpvOpAtomicAnd:
      case SpvOpAtomicOr:
      case SpvOpAtomicXor:
        // In-operands 1 and 2 are the memory scope and memory semantics ids.
        if (use_in_operand_index == 1 || use_in_operand_index == 2) {
          return false;
        }
        break;
      case SpvOpAtomicCompareExchange:
        // In-operand 1 is the memory scope; 2 and 3 are the two memory
        // semantics ids.
        if (use_in_operand_index == 1 || use_in_operand_index == 2 ||
            use_in_operand_index == 3) {
          return false;
        }
        break;
      case SpvOpAtomicCompareExchangeWeak:
      case SpvOpAtomicFlagTestAndSet:
      case SpvOpAtomicFlagClear:
      case SpvOpAtomicFAddEXT:
        // With asserts disabled, these opcodes fall through to |default| and
        // the use is treated as replaceable.
        assert(false && "Not allowed with the Shader capability.");
      default:
        break;
    }
  }

  return true;
}
1714
MembersHaveBuiltInDecoration(opt::IRContext * ir_context,uint32_t struct_type_id)1715 bool MembersHaveBuiltInDecoration(opt::IRContext* ir_context,
1716 uint32_t struct_type_id) {
1717 const auto* type_inst = ir_context->get_def_use_mgr()->GetDef(struct_type_id);
1718 assert(type_inst && type_inst->opcode() == SpvOpTypeStruct &&
1719 "|struct_type_id| is not a result id of an OpTypeStruct");
1720
1721 uint32_t builtin_count = 0;
1722 ir_context->get_def_use_mgr()->ForEachUser(
1723 type_inst,
1724 [struct_type_id, &builtin_count](const opt::Instruction* user) {
1725 if (user->opcode() == SpvOpMemberDecorate &&
1726 user->GetSingleWordInOperand(0) == struct_type_id &&
1727 static_cast<SpvDecoration>(user->GetSingleWordInOperand(2)) ==
1728 SpvDecorationBuiltIn) {
1729 ++builtin_count;
1730 }
1731 });
1732
1733 assert((builtin_count == 0 || builtin_count == type_inst->NumInOperands()) &&
1734 "The module is invalid: either none or all of the members of "
1735 "|struct_type_id| may be builtin");
1736
1737 return builtin_count != 0;
1738 }
1739
HasBlockOrBufferBlockDecoration(opt::IRContext * ir_context,uint32_t id)1740 bool HasBlockOrBufferBlockDecoration(opt::IRContext* ir_context, uint32_t id) {
1741 for (auto decoration : {SpvDecorationBlock, SpvDecorationBufferBlock}) {
1742 if (!ir_context->get_decoration_mgr()->WhileEachDecoration(
1743 id, decoration, [](const opt::Instruction & /*unused*/) -> bool {
1744 return false;
1745 })) {
1746 return true;
1747 }
1748 }
1749 return false;
1750 }
1751
SplittingBeforeInstructionSeparatesOpSampledImageDefinitionFromUse(opt::BasicBlock * block_to_split,opt::Instruction * split_before)1752 bool SplittingBeforeInstructionSeparatesOpSampledImageDefinitionFromUse(
1753 opt::BasicBlock* block_to_split, opt::Instruction* split_before) {
1754 std::set<uint32_t> sampled_image_result_ids;
1755 bool before_split = true;
1756
1757 // Check all the instructions in the block to split.
1758 for (auto& instruction : *block_to_split) {
1759 if (&instruction == &*split_before) {
1760 before_split = false;
1761 }
1762 if (before_split) {
1763 // If the instruction comes before the split and its opcode is
1764 // OpSampledImage, record its result id.
1765 if (instruction.opcode() == SpvOpSampledImage) {
1766 sampled_image_result_ids.insert(instruction.result_id());
1767 }
1768 } else {
1769 // If the instruction comes after the split, check if ids
1770 // corresponding to OpSampledImage instructions defined before the split
1771 // are used, and return true if they are.
1772 if (!instruction.WhileEachInId(
1773 [&sampled_image_result_ids](uint32_t* id) -> bool {
1774 return !sampled_image_result_ids.count(*id);
1775 })) {
1776 return true;
1777 }
1778 }
1779 }
1780
1781 // No usage that would be separated from the definition has been found.
1782 return false;
1783 }
1784
InstructionHasNoSideEffects(const opt::Instruction & instruction)1785 bool InstructionHasNoSideEffects(const opt::Instruction& instruction) {
1786 switch (instruction.opcode()) {
1787 case SpvOpUndef:
1788 case SpvOpAccessChain:
1789 case SpvOpInBoundsAccessChain:
1790 case SpvOpArrayLength:
1791 case SpvOpVectorExtractDynamic:
1792 case SpvOpVectorInsertDynamic:
1793 case SpvOpVectorShuffle:
1794 case SpvOpCompositeConstruct:
1795 case SpvOpCompositeExtract:
1796 case SpvOpCompositeInsert:
1797 case SpvOpCopyObject:
1798 case SpvOpTranspose:
1799 case SpvOpConvertFToU:
1800 case SpvOpConvertFToS:
1801 case SpvOpConvertSToF:
1802 case SpvOpConvertUToF:
1803 case SpvOpUConvert:
1804 case SpvOpSConvert:
1805 case SpvOpFConvert:
1806 case SpvOpQuantizeToF16:
1807 case SpvOpSatConvertSToU:
1808 case SpvOpSatConvertUToS:
1809 case SpvOpBitcast:
1810 case SpvOpSNegate:
1811 case SpvOpFNegate:
1812 case SpvOpIAdd:
1813 case SpvOpFAdd:
1814 case SpvOpISub:
1815 case SpvOpFSub:
1816 case SpvOpIMul:
1817 case SpvOpFMul:
1818 case SpvOpUDiv:
1819 case SpvOpSDiv:
1820 case SpvOpFDiv:
1821 case SpvOpUMod:
1822 case SpvOpSRem:
1823 case SpvOpSMod:
1824 case SpvOpFRem:
1825 case SpvOpFMod:
1826 case SpvOpVectorTimesScalar:
1827 case SpvOpMatrixTimesScalar:
1828 case SpvOpVectorTimesMatrix:
1829 case SpvOpMatrixTimesVector:
1830 case SpvOpMatrixTimesMatrix:
1831 case SpvOpOuterProduct:
1832 case SpvOpDot:
1833 case SpvOpIAddCarry:
1834 case SpvOpISubBorrow:
1835 case SpvOpUMulExtended:
1836 case SpvOpSMulExtended:
1837 case SpvOpAny:
1838 case SpvOpAll:
1839 case SpvOpIsNan:
1840 case SpvOpIsInf:
1841 case SpvOpIsFinite:
1842 case SpvOpIsNormal:
1843 case SpvOpSignBitSet:
1844 case SpvOpLessOrGreater:
1845 case SpvOpOrdered:
1846 case SpvOpUnordered:
1847 case SpvOpLogicalEqual:
1848 case SpvOpLogicalNotEqual:
1849 case SpvOpLogicalOr:
1850 case SpvOpLogicalAnd:
1851 case SpvOpLogicalNot:
1852 case SpvOpSelect:
1853 case SpvOpIEqual:
1854 case SpvOpINotEqual:
1855 case SpvOpUGreaterThan:
1856 case SpvOpSGreaterThan:
1857 case SpvOpUGreaterThanEqual:
1858 case SpvOpSGreaterThanEqual:
1859 case SpvOpULessThan:
1860 case SpvOpSLessThan:
1861 case SpvOpULessThanEqual:
1862 case SpvOpSLessThanEqual:
1863 case SpvOpFOrdEqual:
1864 case SpvOpFUnordEqual:
1865 case SpvOpFOrdNotEqual:
1866 case SpvOpFUnordNotEqual:
1867 case SpvOpFOrdLessThan:
1868 case SpvOpFUnordLessThan:
1869 case SpvOpFOrdGreaterThan:
1870 case SpvOpFUnordGreaterThan:
1871 case SpvOpFOrdLessThanEqual:
1872 case SpvOpFUnordLessThanEqual:
1873 case SpvOpFOrdGreaterThanEqual:
1874 case SpvOpFUnordGreaterThanEqual:
1875 case SpvOpShiftRightLogical:
1876 case SpvOpShiftRightArithmetic:
1877 case SpvOpShiftLeftLogical:
1878 case SpvOpBitwiseOr:
1879 case SpvOpBitwiseXor:
1880 case SpvOpBitwiseAnd:
1881 case SpvOpNot:
1882 case SpvOpBitFieldInsert:
1883 case SpvOpBitFieldSExtract:
1884 case SpvOpBitFieldUExtract:
1885 case SpvOpBitReverse:
1886 case SpvOpBitCount:
1887 case SpvOpCopyLogical:
1888 case SpvOpPhi:
1889 case SpvOpPtrEqual:
1890 case SpvOpPtrNotEqual:
1891 return true;
1892 default:
1893 return false;
1894 }
1895 }
1896
GetReachableReturnBlocks(opt::IRContext * ir_context,uint32_t function_id)1897 std::set<uint32_t> GetReachableReturnBlocks(opt::IRContext* ir_context,
1898 uint32_t function_id) {
1899 auto function = ir_context->GetFunction(function_id);
1900 assert(function && "The function |function_id| must exist.");
1901
1902 std::set<uint32_t> result;
1903
1904 ir_context->cfg()->ForEachBlockInPostOrder(function->entry().get(),
1905 [&result](opt::BasicBlock* block) {
1906 if (block->IsReturn()) {
1907 result.emplace(block->id());
1908 }
1909 });
1910
1911 return result;
1912 }
1913
NewTerminatorPreservesDominationRules(opt::IRContext * ir_context,uint32_t block_id,opt::Instruction new_terminator)1914 bool NewTerminatorPreservesDominationRules(opt::IRContext* ir_context,
1915 uint32_t block_id,
1916 opt::Instruction new_terminator) {
1917 auto* mutated_block = MaybeFindBlock(ir_context, block_id);
1918 assert(mutated_block && "|block_id| is invalid");
1919
1920 ChangeTerminatorRAII change_terminator_raii(mutated_block,
1921 std::move(new_terminator));
1922 opt::DominatorAnalysis dominator_analysis;
1923 dominator_analysis.InitializeTree(*ir_context->cfg(),
1924 mutated_block->GetParent());
1925
1926 // Check that each dominator appears before each dominated block.
1927 std::unordered_map<uint32_t, size_t> positions;
1928 for (const auto& block : *mutated_block->GetParent()) {
1929 positions[block.id()] = positions.size();
1930 }
1931
1932 std::queue<uint32_t> q({mutated_block->GetParent()->begin()->id()});
1933 std::unordered_set<uint32_t> visited;
1934 while (!q.empty()) {
1935 auto block = q.front();
1936 q.pop();
1937 visited.insert(block);
1938
1939 auto success = ir_context->cfg()->block(block)->WhileEachSuccessorLabel(
1940 [&positions, &visited, &dominator_analysis, block, &q](uint32_t id) {
1941 if (id == block) {
1942 // Handle the case when loop header and continue target are the same
1943 // block.
1944 return true;
1945 }
1946
1947 if (dominator_analysis.Dominates(block, id) &&
1948 positions[block] > positions[id]) {
1949 // |block| dominates |id| but appears after |id| - violates
1950 // domination rules.
1951 return false;
1952 }
1953
1954 if (!visited.count(id)) {
1955 q.push(id);
1956 }
1957
1958 return true;
1959 });
1960
1961 if (!success) {
1962 return false;
1963 }
1964 }
1965
1966 // For each instruction in the |block->GetParent()| function check whether
1967 // all its dependencies satisfy domination rules (i.e. all id operands
1968 // dominate that instruction).
1969 for (const auto& block : *mutated_block->GetParent()) {
1970 if (!ir_context->IsReachable(block)) {
1971 // If some block is not reachable then we don't need to worry about the
1972 // preservation of domination rules for its instructions.
1973 continue;
1974 }
1975
1976 for (const auto& inst : block) {
1977 for (uint32_t i = 0; i < inst.NumInOperands();
1978 i += inst.opcode() == SpvOpPhi ? 2 : 1) {
1979 const auto& operand = inst.GetInOperand(i);
1980 if (!spvIsInIdType(operand.type)) {
1981 continue;
1982 }
1983
1984 if (MaybeFindBlock(ir_context, operand.words[0])) {
1985 // Ignore operands that refer to OpLabel instructions.
1986 continue;
1987 }
1988
1989 const auto* dependency_block =
1990 ir_context->get_instr_block(operand.words[0]);
1991 if (!dependency_block) {
1992 // A global instruction always dominates all instructions in any
1993 // function.
1994 continue;
1995 }
1996
1997 auto domination_target_id = inst.opcode() == SpvOpPhi
1998 ? inst.GetSingleWordInOperand(i + 1)
1999 : block.id();
2000
2001 if (!dominator_analysis.Dominates(dependency_block->id(),
2002 domination_target_id)) {
2003 return false;
2004 }
2005 }
2006 }
2007 }
2008
2009 return true;
2010 }
2011
GetFunctionIterator(opt::IRContext * ir_context,uint32_t function_id)2012 opt::Module::iterator GetFunctionIterator(opt::IRContext* ir_context,
2013 uint32_t function_id) {
2014 return std::find_if(ir_context->module()->begin(),
2015 ir_context->module()->end(),
2016 [function_id](const opt::Function& f) {
2017 return f.result_id() == function_id;
2018 });
2019 }
2020
2021 // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3582): Add all
2022 // opcodes that are agnostic to signedness of operands to function.
2023 // This is not exhaustive yet.
IsAgnosticToSignednessOfOperand(SpvOp opcode,uint32_t use_in_operand_index)2024 bool IsAgnosticToSignednessOfOperand(SpvOp opcode,
2025 uint32_t use_in_operand_index) {
2026 switch (opcode) {
2027 case SpvOpSNegate:
2028 case SpvOpNot:
2029 case SpvOpIAdd:
2030 case SpvOpISub:
2031 case SpvOpIMul:
2032 case SpvOpSDiv:
2033 case SpvOpSRem:
2034 case SpvOpSMod:
2035 case SpvOpShiftRightLogical:
2036 case SpvOpShiftRightArithmetic:
2037 case SpvOpShiftLeftLogical:
2038 case SpvOpBitwiseOr:
2039 case SpvOpBitwiseXor:
2040 case SpvOpBitwiseAnd:
2041 case SpvOpIEqual:
2042 case SpvOpINotEqual:
2043 case SpvOpULessThan:
2044 case SpvOpSLessThan:
2045 case SpvOpUGreaterThan:
2046 case SpvOpSGreaterThan:
2047 case SpvOpULessThanEqual:
2048 case SpvOpSLessThanEqual:
2049 case SpvOpUGreaterThanEqual:
2050 case SpvOpSGreaterThanEqual:
2051 return true;
2052
2053 case SpvOpAtomicStore:
2054 case SpvOpAtomicExchange:
2055 case SpvOpAtomicIAdd:
2056 case SpvOpAtomicISub:
2057 case SpvOpAtomicSMin:
2058 case SpvOpAtomicUMin:
2059 case SpvOpAtomicSMax:
2060 case SpvOpAtomicUMax:
2061 case SpvOpAtomicAnd:
2062 case SpvOpAtomicOr:
2063 case SpvOpAtomicXor:
2064 case SpvOpAtomicFAddEXT: // Capability AtomicFloat32AddEXT,
2065 // AtomicFloat64AddEXT.
2066 assert(use_in_operand_index != 0 &&
2067 "Signedness check should not occur on a pointer operand.");
2068 return use_in_operand_index == 1 || use_in_operand_index == 2;
2069
2070 case SpvOpAtomicCompareExchange:
2071 case SpvOpAtomicCompareExchangeWeak: // Capability Kernel.
2072 assert(use_in_operand_index != 0 &&
2073 "Signedness check should not occur on a pointer operand.");
2074 return use_in_operand_index >= 1 && use_in_operand_index <= 3;
2075
2076 case SpvOpAtomicLoad:
2077 case SpvOpAtomicIIncrement:
2078 case SpvOpAtomicIDecrement:
2079 case SpvOpAtomicFlagTestAndSet: // Capability Kernel.
2080 case SpvOpAtomicFlagClear: // Capability Kernel.
2081 assert(use_in_operand_index != 0 &&
2082 "Signedness check should not occur on a pointer operand.");
2083 return use_in_operand_index >= 1;
2084
2085 case SpvOpAccessChain:
2086 // The signedness of indices does not matter.
2087 return use_in_operand_index > 0;
2088
2089 default:
2090 // Conservatively assume that the id cannot be swapped in other
2091 // instructions.
2092 return false;
2093 }
2094 }
2095
TypesAreCompatible(opt::IRContext * ir_context,SpvOp opcode,uint32_t use_in_operand_index,uint32_t type_id_1,uint32_t type_id_2)2096 bool TypesAreCompatible(opt::IRContext* ir_context, SpvOp opcode,
2097 uint32_t use_in_operand_index, uint32_t type_id_1,
2098 uint32_t type_id_2) {
2099 assert(ir_context->get_type_mgr()->GetType(type_id_1) &&
2100 ir_context->get_type_mgr()->GetType(type_id_2) &&
2101 "Type ids are invalid");
2102
2103 return type_id_1 == type_id_2 ||
2104 (IsAgnosticToSignednessOfOperand(opcode, use_in_operand_index) &&
2105 fuzzerutil::TypesAreEqualUpToSign(ir_context, type_id_1, type_id_2));
2106 }
2107
2108 } // namespace fuzzerutil
2109 } // namespace fuzz
2110 } // namespace spvtools
2111