1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/fuzz/fuzzer_util.h"
16
17 #include <algorithm>
18 #include <unordered_set>
19
20 #include "source/opt/build_module.h"
21
22 namespace spvtools {
23 namespace fuzz {
24
25 namespace fuzzerutil {
26 namespace {
27
28 // A utility class that uses RAII to change and restore the terminator
29 // instruction of the |block|.
30 class ChangeTerminatorRAII {
31 public:
ChangeTerminatorRAII(opt::BasicBlock * block,opt::Instruction new_terminator)32 explicit ChangeTerminatorRAII(opt::BasicBlock* block,
33 opt::Instruction new_terminator)
34 : block_(block), old_terminator_(std::move(*block->terminator())) {
35 *block_->terminator() = std::move(new_terminator);
36 }
37
~ChangeTerminatorRAII()38 ~ChangeTerminatorRAII() {
39 *block_->terminator() = std::move(old_terminator_);
40 }
41
42 private:
43 opt::BasicBlock* block_;
44 opt::Instruction old_terminator_;
45 };
46
MaybeGetOpConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t type_id,bool is_irrelevant)47 uint32_t MaybeGetOpConstant(opt::IRContext* ir_context,
48 const TransformationContext& transformation_context,
49 const std::vector<uint32_t>& words,
50 uint32_t type_id, bool is_irrelevant) {
51 for (const auto& inst : ir_context->types_values()) {
52 if (inst.opcode() == SpvOpConstant && inst.type_id() == type_id &&
53 inst.GetInOperand(0).words == words &&
54 transformation_context.GetFactManager()->IdIsIrrelevant(
55 inst.result_id()) == is_irrelevant) {
56 return inst.result_id();
57 }
58 }
59
60 return 0;
61 }
62
63 } // namespace
64
65 const spvtools::MessageConsumer kSilentMessageConsumer =
66 [](spv_message_level_t, const char*, const spv_position_t&,
__anonbc26db0a0202(spv_message_level_t, const char*, const spv_position_t&, const char*) 67 const char*) -> void {};
68
BuildIRContext(spv_target_env target_env,const spvtools::MessageConsumer & message_consumer,const std::vector<uint32_t> & binary_in,spv_validator_options validator_options,std::unique_ptr<spvtools::opt::IRContext> * ir_context)69 bool BuildIRContext(spv_target_env target_env,
70 const spvtools::MessageConsumer& message_consumer,
71 const std::vector<uint32_t>& binary_in,
72 spv_validator_options validator_options,
73 std::unique_ptr<spvtools::opt::IRContext>* ir_context) {
74 SpirvTools tools(target_env);
75 tools.SetMessageConsumer(message_consumer);
76 if (!tools.IsValid()) {
77 message_consumer(SPV_MSG_ERROR, nullptr, {},
78 "Failed to create SPIRV-Tools interface; stopping.");
79 return false;
80 }
81
82 // Initial binary should be valid.
83 if (!tools.Validate(binary_in.data(), binary_in.size(), validator_options)) {
84 message_consumer(SPV_MSG_ERROR, nullptr, {},
85 "Initial binary is invalid; stopping.");
86 return false;
87 }
88
89 // Build the module from the input binary.
90 auto result = BuildModule(target_env, message_consumer, binary_in.data(),
91 binary_in.size());
92 assert(result && "IRContext must be valid");
93 *ir_context = std::move(result);
94 return true;
95 }
96
IsFreshId(opt::IRContext * context,uint32_t id)97 bool IsFreshId(opt::IRContext* context, uint32_t id) {
98 return !context->get_def_use_mgr()->GetDef(id);
99 }
100
UpdateModuleIdBound(opt::IRContext * context,uint32_t id)101 void UpdateModuleIdBound(opt::IRContext* context, uint32_t id) {
102 // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/2541) consider the
103 // case where the maximum id bound is reached.
104 context->module()->SetIdBound(
105 std::max(context->module()->id_bound(), id + 1));
106 }
107
MaybeFindBlock(opt::IRContext * context,uint32_t maybe_block_id)108 opt::BasicBlock* MaybeFindBlock(opt::IRContext* context,
109 uint32_t maybe_block_id) {
110 auto inst = context->get_def_use_mgr()->GetDef(maybe_block_id);
111 if (inst == nullptr) {
112 // No instruction defining this id was found.
113 return nullptr;
114 }
115 if (inst->opcode() != SpvOpLabel) {
116 // The instruction defining the id is not a label, so it cannot be a block
117 // id.
118 return nullptr;
119 }
120 return context->cfg()->block(maybe_block_id);
121 }
122
PhiIdsOkForNewEdge(opt::IRContext * context,opt::BasicBlock * bb_from,opt::BasicBlock * bb_to,const google::protobuf::RepeatedField<google::protobuf::uint32> & phi_ids)123 bool PhiIdsOkForNewEdge(
124 opt::IRContext* context, opt::BasicBlock* bb_from, opt::BasicBlock* bb_to,
125 const google::protobuf::RepeatedField<google::protobuf::uint32>& phi_ids) {
126 if (bb_from->IsSuccessor(bb_to)) {
127 // There is already an edge from |from_block| to |to_block|, so there is
128 // no need to extend OpPhi instructions. Do not allow phi ids to be
129 // present. This might turn out to be too strict; perhaps it would be OK
130 // just to ignore the ids in this case.
131 return phi_ids.empty();
132 }
133 // The edge would add a previously non-existent edge from |from_block| to
134 // |to_block|, so we go through the given phi ids and check that they exactly
135 // match the OpPhi instructions in |to_block|.
136 uint32_t phi_index = 0;
137 // An explicit loop, rather than applying a lambda to each OpPhi in |bb_to|,
138 // makes sense here because we need to increment |phi_index| for each OpPhi
139 // instruction.
140 for (auto& inst : *bb_to) {
141 if (inst.opcode() != SpvOpPhi) {
142 // The OpPhi instructions all occur at the start of the block; if we find
143 // a non-OpPhi then we have seen them all.
144 break;
145 }
146 if (phi_index == static_cast<uint32_t>(phi_ids.size())) {
147 // Not enough phi ids have been provided to account for the OpPhi
148 // instructions.
149 return false;
150 }
151 // Look for an instruction defining the next phi id.
152 opt::Instruction* phi_extension =
153 context->get_def_use_mgr()->GetDef(phi_ids[phi_index]);
154 if (!phi_extension) {
155 // The id given to extend this OpPhi does not exist.
156 return false;
157 }
158 if (phi_extension->type_id() != inst.type_id()) {
159 // The instruction given to extend this OpPhi either does not have a type
160 // or its type does not match that of the OpPhi.
161 return false;
162 }
163
164 if (context->get_instr_block(phi_extension)) {
165 // The instruction defining the phi id has an associated block (i.e., it
166 // is not a global value). Check whether its definition dominates the
167 // exit of |from_block|.
168 auto dominator_analysis =
169 context->GetDominatorAnalysis(bb_from->GetParent());
170 if (!dominator_analysis->Dominates(phi_extension,
171 bb_from->terminator())) {
172 // The given id is no good as its definition does not dominate the exit
173 // of |from_block|
174 return false;
175 }
176 }
177 phi_index++;
178 }
179 // We allow some of the ids provided for extending OpPhi instructions to be
180 // unused. Their presence does no harm, and requiring a perfect match may
181 // make transformations less likely to cleanly apply.
182 return true;
183 }
184
CreateUnreachableEdgeInstruction(opt::IRContext * ir_context,uint32_t bb_from_id,uint32_t bb_to_id,uint32_t bool_id)185 opt::Instruction CreateUnreachableEdgeInstruction(opt::IRContext* ir_context,
186 uint32_t bb_from_id,
187 uint32_t bb_to_id,
188 uint32_t bool_id) {
189 const auto* bb_from = MaybeFindBlock(ir_context, bb_from_id);
190 assert(bb_from && "|bb_from_id| is invalid");
191 assert(MaybeFindBlock(ir_context, bb_to_id) && "|bb_to_id| is invalid");
192 assert(bb_from->terminator()->opcode() == SpvOpBranch &&
193 "Precondition on terminator of bb_from is not satisfied");
194
195 // Get the id of the boolean constant to be used as the condition.
196 auto condition_inst = ir_context->get_def_use_mgr()->GetDef(bool_id);
197 assert(condition_inst &&
198 (condition_inst->opcode() == SpvOpConstantTrue ||
199 condition_inst->opcode() == SpvOpConstantFalse) &&
200 "|bool_id| is invalid");
201
202 auto condition_value = condition_inst->opcode() == SpvOpConstantTrue;
203 auto successor_id = bb_from->terminator()->GetSingleWordInOperand(0);
204
205 // Add the dead branch, by turning OpBranch into OpBranchConditional, and
206 // ordering the targets depending on whether the given boolean corresponds to
207 // true or false.
208 return opt::Instruction(
209 ir_context, SpvOpBranchConditional, 0, 0,
210 {{SPV_OPERAND_TYPE_ID, {bool_id}},
211 {SPV_OPERAND_TYPE_ID, {condition_value ? successor_id : bb_to_id}},
212 {SPV_OPERAND_TYPE_ID, {condition_value ? bb_to_id : successor_id}}});
213 }
214
AddUnreachableEdgeAndUpdateOpPhis(opt::IRContext * context,opt::BasicBlock * bb_from,opt::BasicBlock * bb_to,uint32_t bool_id,const google::protobuf::RepeatedField<google::protobuf::uint32> & phi_ids)215 void AddUnreachableEdgeAndUpdateOpPhis(
216 opt::IRContext* context, opt::BasicBlock* bb_from, opt::BasicBlock* bb_to,
217 uint32_t bool_id,
218 const google::protobuf::RepeatedField<google::protobuf::uint32>& phi_ids) {
219 assert(PhiIdsOkForNewEdge(context, bb_from, bb_to, phi_ids) &&
220 "Precondition on phi_ids is not satisfied");
221
222 const bool from_to_edge_already_exists = bb_from->IsSuccessor(bb_to);
223 *bb_from->terminator() = CreateUnreachableEdgeInstruction(
224 context, bb_from->id(), bb_to->id(), bool_id);
225
226 // Update OpPhi instructions in the target block if this branch adds a
227 // previously non-existent edge from source to target.
228 if (!from_to_edge_already_exists) {
229 uint32_t phi_index = 0;
230 for (auto& inst : *bb_to) {
231 if (inst.opcode() != SpvOpPhi) {
232 break;
233 }
234 assert(phi_index < static_cast<uint32_t>(phi_ids.size()) &&
235 "There should be at least one phi id per OpPhi instruction.");
236 inst.AddOperand({SPV_OPERAND_TYPE_ID, {phi_ids[phi_index]}});
237 inst.AddOperand({SPV_OPERAND_TYPE_ID, {bb_from->id()}});
238 phi_index++;
239 }
240 }
241 }
242
BlockIsBackEdge(opt::IRContext * context,uint32_t block_id,uint32_t loop_header_id)243 bool BlockIsBackEdge(opt::IRContext* context, uint32_t block_id,
244 uint32_t loop_header_id) {
245 auto block = context->cfg()->block(block_id);
246 auto loop_header = context->cfg()->block(loop_header_id);
247
248 // |block| and |loop_header| must be defined, |loop_header| must be in fact
249 // loop header and |block| must branch to it.
250 if (!(block && loop_header && loop_header->IsLoopHeader() &&
251 block->IsSuccessor(loop_header))) {
252 return false;
253 }
254
255 // |block| must be reachable and be dominated by |loop_header|.
256 opt::DominatorAnalysis* dominator_analysis =
257 context->GetDominatorAnalysis(loop_header->GetParent());
258 return context->IsReachable(*block) &&
259 dominator_analysis->Dominates(loop_header, block);
260 }
261
BlockIsInLoopContinueConstruct(opt::IRContext * context,uint32_t block_id,uint32_t maybe_loop_header_id)262 bool BlockIsInLoopContinueConstruct(opt::IRContext* context, uint32_t block_id,
263 uint32_t maybe_loop_header_id) {
264 // We deem a block to be part of a loop's continue construct if the loop's
265 // continue target dominates the block.
266 auto containing_construct_block = context->cfg()->block(maybe_loop_header_id);
267 if (containing_construct_block->IsLoopHeader()) {
268 auto continue_target = containing_construct_block->ContinueBlockId();
269 if (context->GetDominatorAnalysis(containing_construct_block->GetParent())
270 ->Dominates(continue_target, block_id)) {
271 return true;
272 }
273 }
274 return false;
275 }
276
GetIteratorForInstruction(opt::BasicBlock * block,const opt::Instruction * inst)277 opt::BasicBlock::iterator GetIteratorForInstruction(
278 opt::BasicBlock* block, const opt::Instruction* inst) {
279 for (auto inst_it = block->begin(); inst_it != block->end(); ++inst_it) {
280 if (inst == &*inst_it) {
281 return inst_it;
282 }
283 }
284 return block->end();
285 }
286
CanInsertOpcodeBeforeInstruction(SpvOp opcode,const opt::BasicBlock::iterator & instruction_in_block)287 bool CanInsertOpcodeBeforeInstruction(
288 SpvOp opcode, const opt::BasicBlock::iterator& instruction_in_block) {
289 if (instruction_in_block->PreviousNode() &&
290 (instruction_in_block->PreviousNode()->opcode() == SpvOpLoopMerge ||
291 instruction_in_block->PreviousNode()->opcode() == SpvOpSelectionMerge)) {
292 // We cannot insert directly after a merge instruction.
293 return false;
294 }
295 if (opcode != SpvOpVariable &&
296 instruction_in_block->opcode() == SpvOpVariable) {
297 // We cannot insert a non-OpVariable instruction directly before a
298 // variable; variables in a function must be contiguous in the entry block.
299 return false;
300 }
301 // We cannot insert a non-OpPhi instruction directly before an OpPhi, because
302 // OpPhi instructions need to be contiguous at the start of a block.
303 return opcode == SpvOpPhi || instruction_in_block->opcode() != SpvOpPhi;
304 }
305
CanMakeSynonymOf(opt::IRContext * ir_context,const TransformationContext & transformation_context,opt::Instruction * inst)306 bool CanMakeSynonymOf(opt::IRContext* ir_context,
307 const TransformationContext& transformation_context,
308 opt::Instruction* inst) {
309 if (inst->opcode() == SpvOpSampledImage) {
310 // The SPIR-V data rules say that only very specific instructions may
311 // may consume the result id of an OpSampledImage, and this excludes the
312 // instructions that are used for making synonyms.
313 return false;
314 }
315 if (!inst->HasResultId()) {
316 // We can only make a synonym of an instruction that generates an id.
317 return false;
318 }
319 if (transformation_context.GetFactManager()->IdIsIrrelevant(
320 inst->result_id())) {
321 // An irrelevant id can't be a synonym of anything.
322 return false;
323 }
324 if (!inst->type_id()) {
325 // We can only make a synonym of an instruction that has a type.
326 return false;
327 }
328 auto type_inst = ir_context->get_def_use_mgr()->GetDef(inst->type_id());
329 if (type_inst->opcode() == SpvOpTypeVoid) {
330 // We only make synonyms of instructions that define objects, and an object
331 // cannot have void type.
332 return false;
333 }
334 if (type_inst->opcode() == SpvOpTypePointer) {
335 switch (inst->opcode()) {
336 case SpvOpConstantNull:
337 case SpvOpUndef:
338 // We disallow making synonyms of null or undefined pointers. This is
339 // to provide the property that if the original shader exhibited no bad
340 // pointer accesses, the transformed shader will not either.
341 return false;
342 default:
343 break;
344 }
345 }
346
347 // We do not make synonyms of objects that have decorations: if the synonym is
348 // not decorated analogously, using the original object vs. its synonymous
349 // form may not be equivalent.
350 return ir_context->get_decoration_mgr()
351 ->GetDecorationsFor(inst->result_id(), true)
352 .empty();
353 }
354
IsCompositeType(const opt::analysis::Type * type)355 bool IsCompositeType(const opt::analysis::Type* type) {
356 return type && (type->AsArray() || type->AsMatrix() || type->AsStruct() ||
357 type->AsVector());
358 }
359
RepeatedFieldToVector(const google::protobuf::RepeatedField<uint32_t> & repeated_field)360 std::vector<uint32_t> RepeatedFieldToVector(
361 const google::protobuf::RepeatedField<uint32_t>& repeated_field) {
362 std::vector<uint32_t> result;
363 for (auto i : repeated_field) {
364 result.push_back(i);
365 }
366 return result;
367 }
368
WalkOneCompositeTypeIndex(opt::IRContext * context,uint32_t base_object_type_id,uint32_t index)369 uint32_t WalkOneCompositeTypeIndex(opt::IRContext* context,
370 uint32_t base_object_type_id,
371 uint32_t index) {
372 auto should_be_composite_type =
373 context->get_def_use_mgr()->GetDef(base_object_type_id);
374 assert(should_be_composite_type && "The type should exist.");
375 switch (should_be_composite_type->opcode()) {
376 case SpvOpTypeArray: {
377 auto array_length = GetArraySize(*should_be_composite_type, context);
378 if (array_length == 0 || index >= array_length) {
379 return 0;
380 }
381 return should_be_composite_type->GetSingleWordInOperand(0);
382 }
383 case SpvOpTypeMatrix:
384 case SpvOpTypeVector: {
385 auto count = should_be_composite_type->GetSingleWordInOperand(1);
386 if (index >= count) {
387 return 0;
388 }
389 return should_be_composite_type->GetSingleWordInOperand(0);
390 }
391 case SpvOpTypeStruct: {
392 if (index >= GetNumberOfStructMembers(*should_be_composite_type)) {
393 return 0;
394 }
395 return should_be_composite_type->GetSingleWordInOperand(index);
396 }
397 default:
398 return 0;
399 }
400 }
401
WalkCompositeTypeIndices(opt::IRContext * context,uint32_t base_object_type_id,const google::protobuf::RepeatedField<google::protobuf::uint32> & indices)402 uint32_t WalkCompositeTypeIndices(
403 opt::IRContext* context, uint32_t base_object_type_id,
404 const google::protobuf::RepeatedField<google::protobuf::uint32>& indices) {
405 uint32_t sub_object_type_id = base_object_type_id;
406 for (auto index : indices) {
407 sub_object_type_id =
408 WalkOneCompositeTypeIndex(context, sub_object_type_id, index);
409 if (!sub_object_type_id) {
410 return 0;
411 }
412 }
413 return sub_object_type_id;
414 }
415
GetNumberOfStructMembers(const opt::Instruction & struct_type_instruction)416 uint32_t GetNumberOfStructMembers(
417 const opt::Instruction& struct_type_instruction) {
418 assert(struct_type_instruction.opcode() == SpvOpTypeStruct &&
419 "An OpTypeStruct instruction is required here.");
420 return struct_type_instruction.NumInOperands();
421 }
422
GetArraySize(const opt::Instruction & array_type_instruction,opt::IRContext * context)423 uint32_t GetArraySize(const opt::Instruction& array_type_instruction,
424 opt::IRContext* context) {
425 auto array_length_constant =
426 context->get_constant_mgr()
427 ->GetConstantFromInst(context->get_def_use_mgr()->GetDef(
428 array_type_instruction.GetSingleWordInOperand(1)))
429 ->AsIntConstant();
430 if (array_length_constant->words().size() != 1) {
431 return 0;
432 }
433 return array_length_constant->GetU32();
434 }
435
GetBoundForCompositeIndex(const opt::Instruction & composite_type_inst,opt::IRContext * ir_context)436 uint32_t GetBoundForCompositeIndex(const opt::Instruction& composite_type_inst,
437 opt::IRContext* ir_context) {
438 switch (composite_type_inst.opcode()) {
439 case SpvOpTypeArray:
440 return fuzzerutil::GetArraySize(composite_type_inst, ir_context);
441 case SpvOpTypeMatrix:
442 case SpvOpTypeVector:
443 return composite_type_inst.GetSingleWordInOperand(1);
444 case SpvOpTypeStruct: {
445 return fuzzerutil::GetNumberOfStructMembers(composite_type_inst);
446 }
447 case SpvOpTypeRuntimeArray:
448 assert(false &&
449 "GetBoundForCompositeIndex should not be invoked with an "
450 "OpTypeRuntimeArray, which does not have a static bound.");
451 return 0;
452 default:
453 assert(false && "Unknown composite type.");
454 return 0;
455 }
456 }
457
IsValid(const opt::IRContext * context,spv_validator_options validator_options,MessageConsumer consumer)458 bool IsValid(const opt::IRContext* context,
459 spv_validator_options validator_options,
460 MessageConsumer consumer) {
461 std::vector<uint32_t> binary;
462 context->module()->ToBinary(&binary, false);
463 SpirvTools tools(context->grammar().target_env());
464 tools.SetMessageConsumer(std::move(consumer));
465 return tools.Validate(binary.data(), binary.size(), validator_options);
466 }
467
IsValidAndWellFormed(const opt::IRContext * ir_context,spv_validator_options validator_options,MessageConsumer consumer)468 bool IsValidAndWellFormed(const opt::IRContext* ir_context,
469 spv_validator_options validator_options,
470 MessageConsumer consumer) {
471 if (!IsValid(ir_context, validator_options, consumer)) {
472 // Expression to dump |ir_context| to /data/temp/shader.spv:
473 // DumpShader(ir_context, "/data/temp/shader.spv")
474 consumer(SPV_MSG_INFO, nullptr, {},
475 "Module is invalid (set a breakpoint to inspect).");
476 return false;
477 }
478 // Check that all blocks in the module have appropriate parent functions.
479 for (auto& function : *ir_context->module()) {
480 for (auto& block : function) {
481 if (block.GetParent() == nullptr) {
482 std::stringstream ss;
483 ss << "Block " << block.id() << " has no parent; its parent should be "
484 << function.result_id() << " (set a breakpoint to inspect).";
485 consumer(SPV_MSG_INFO, nullptr, {}, ss.str().c_str());
486 return false;
487 }
488 if (block.GetParent() != &function) {
489 std::stringstream ss;
490 ss << "Block " << block.id() << " should have parent "
491 << function.result_id() << " but instead has parent "
492 << block.GetParent() << " (set a breakpoint to inspect).";
493 consumer(SPV_MSG_INFO, nullptr, {}, ss.str().c_str());
494 return false;
495 }
496 }
497 }
498
499 // Check that all instructions have distinct unique ids. We map each unique
500 // id to the first instruction it is observed to be associated with so that
501 // if we encounter a duplicate we have access to the previous instruction -
502 // this is a useful aid to debugging.
503 std::unordered_map<uint32_t, opt::Instruction*> unique_ids;
504 bool found_duplicate = false;
505 ir_context->module()->ForEachInst([&consumer, &found_duplicate,
506 &unique_ids](opt::Instruction* inst) {
507 if (unique_ids.count(inst->unique_id()) != 0) {
508 consumer(SPV_MSG_INFO, nullptr, {},
509 "Two instructions have the same unique id (set a breakpoint to "
510 "inspect).");
511 found_duplicate = true;
512 }
513 unique_ids.insert({inst->unique_id(), inst});
514 });
515 return !found_duplicate;
516 }
517
CloneIRContext(opt::IRContext * context)518 std::unique_ptr<opt::IRContext> CloneIRContext(opt::IRContext* context) {
519 std::vector<uint32_t> binary;
520 context->module()->ToBinary(&binary, false);
521 return BuildModule(context->grammar().target_env(), nullptr, binary.data(),
522 binary.size());
523 }
524
IsNonFunctionTypeId(opt::IRContext * ir_context,uint32_t id)525 bool IsNonFunctionTypeId(opt::IRContext* ir_context, uint32_t id) {
526 auto type = ir_context->get_type_mgr()->GetType(id);
527 return type && !type->AsFunction();
528 }
529
IsMergeOrContinue(opt::IRContext * ir_context,uint32_t block_id)530 bool IsMergeOrContinue(opt::IRContext* ir_context, uint32_t block_id) {
531 bool result = false;
532 ir_context->get_def_use_mgr()->WhileEachUse(
533 block_id,
534 [&result](const opt::Instruction* use_instruction,
535 uint32_t /*unused*/) -> bool {
536 switch (use_instruction->opcode()) {
537 case SpvOpLoopMerge:
538 case SpvOpSelectionMerge:
539 result = true;
540 return false;
541 default:
542 return true;
543 }
544 });
545 return result;
546 }
547
GetLoopFromMergeBlock(opt::IRContext * ir_context,uint32_t merge_block_id)548 uint32_t GetLoopFromMergeBlock(opt::IRContext* ir_context,
549 uint32_t merge_block_id) {
550 uint32_t result = 0;
551 ir_context->get_def_use_mgr()->WhileEachUse(
552 merge_block_id,
553 [ir_context, &result](opt::Instruction* use_instruction,
554 uint32_t use_index) -> bool {
555 switch (use_instruction->opcode()) {
556 case SpvOpLoopMerge:
557 // The merge block operand is the first operand in OpLoopMerge.
558 if (use_index == 0) {
559 result = ir_context->get_instr_block(use_instruction)->id();
560 return false;
561 }
562 return true;
563 default:
564 return true;
565 }
566 });
567 return result;
568 }
569
FindFunctionType(opt::IRContext * ir_context,const std::vector<uint32_t> & type_ids)570 uint32_t FindFunctionType(opt::IRContext* ir_context,
571 const std::vector<uint32_t>& type_ids) {
572 // Look through the existing types for a match.
573 for (auto& type_or_value : ir_context->types_values()) {
574 if (type_or_value.opcode() != SpvOpTypeFunction) {
575 // We are only interested in function types.
576 continue;
577 }
578 if (type_or_value.NumInOperands() != type_ids.size()) {
579 // Not a match: different numbers of arguments.
580 continue;
581 }
582 // Check whether the return type and argument types match.
583 bool input_operands_match = true;
584 for (uint32_t i = 0; i < type_or_value.NumInOperands(); i++) {
585 if (type_ids[i] != type_or_value.GetSingleWordInOperand(i)) {
586 input_operands_match = false;
587 break;
588 }
589 }
590 if (input_operands_match) {
591 // Everything matches.
592 return type_or_value.result_id();
593 }
594 }
595 // No match was found.
596 return 0;
597 }
598
GetFunctionType(opt::IRContext * context,const opt::Function * function)599 opt::Instruction* GetFunctionType(opt::IRContext* context,
600 const opt::Function* function) {
601 uint32_t type_id = function->DefInst().GetSingleWordInOperand(1);
602 return context->get_def_use_mgr()->GetDef(type_id);
603 }
604
FindFunction(opt::IRContext * ir_context,uint32_t function_id)605 opt::Function* FindFunction(opt::IRContext* ir_context, uint32_t function_id) {
606 for (auto& function : *ir_context->module()) {
607 if (function.result_id() == function_id) {
608 return &function;
609 }
610 }
611 return nullptr;
612 }
613
FunctionContainsOpKillOrUnreachable(const opt::Function & function)614 bool FunctionContainsOpKillOrUnreachable(const opt::Function& function) {
615 for (auto& block : function) {
616 if (block.terminator()->opcode() == SpvOpKill ||
617 block.terminator()->opcode() == SpvOpUnreachable) {
618 return true;
619 }
620 }
621 return false;
622 }
623
FunctionIsEntryPoint(opt::IRContext * context,uint32_t function_id)624 bool FunctionIsEntryPoint(opt::IRContext* context, uint32_t function_id) {
625 for (auto& entry_point : context->module()->entry_points()) {
626 if (entry_point.GetSingleWordInOperand(1) == function_id) {
627 return true;
628 }
629 }
630 return false;
631 }
632
IdIsAvailableAtUse(opt::IRContext * context,opt::Instruction * use_instruction,uint32_t use_input_operand_index,uint32_t id)633 bool IdIsAvailableAtUse(opt::IRContext* context,
634 opt::Instruction* use_instruction,
635 uint32_t use_input_operand_index, uint32_t id) {
636 assert(context->get_instr_block(use_instruction) &&
637 "|use_instruction| must be in a basic block");
638
639 auto defining_instruction = context->get_def_use_mgr()->GetDef(id);
640 auto enclosing_function =
641 context->get_instr_block(use_instruction)->GetParent();
642 // If the id a function parameter, it needs to be associated with the
643 // function containing the use.
644 if (defining_instruction->opcode() == SpvOpFunctionParameter) {
645 return InstructionIsFunctionParameter(defining_instruction,
646 enclosing_function);
647 }
648 if (!context->get_instr_block(id)) {
649 // The id must be at global scope.
650 return true;
651 }
652 if (defining_instruction == use_instruction) {
653 // It is not OK for a definition to use itself.
654 return false;
655 }
656 if (!context->IsReachable(*context->get_instr_block(use_instruction)) ||
657 !context->IsReachable(*context->get_instr_block(id))) {
658 // Skip unreachable blocks.
659 return false;
660 }
661 auto dominator_analysis = context->GetDominatorAnalysis(enclosing_function);
662 if (use_instruction->opcode() == SpvOpPhi) {
663 // In the case where the use is an operand to OpPhi, it is actually the
664 // *parent* block associated with the operand that must be dominated by
665 // the synonym.
666 auto parent_block =
667 use_instruction->GetSingleWordInOperand(use_input_operand_index + 1);
668 return dominator_analysis->Dominates(
669 context->get_instr_block(defining_instruction)->id(), parent_block);
670 }
671 return dominator_analysis->Dominates(defining_instruction, use_instruction);
672 }
673
IdIsAvailableBeforeInstruction(opt::IRContext * context,opt::Instruction * instruction,uint32_t id)674 bool IdIsAvailableBeforeInstruction(opt::IRContext* context,
675 opt::Instruction* instruction,
676 uint32_t id) {
677 assert(context->get_instr_block(instruction) &&
678 "|instruction| must be in a basic block");
679
680 auto id_definition = context->get_def_use_mgr()->GetDef(id);
681 auto function_enclosing_instruction =
682 context->get_instr_block(instruction)->GetParent();
683 // If the id a function parameter, it needs to be associated with the
684 // function containing the instruction.
685 if (id_definition->opcode() == SpvOpFunctionParameter) {
686 return InstructionIsFunctionParameter(id_definition,
687 function_enclosing_instruction);
688 }
689 if (!context->get_instr_block(id)) {
690 // The id is at global scope.
691 return true;
692 }
693 if (id_definition == instruction) {
694 // The instruction is not available right before its own definition.
695 return false;
696 }
697 const auto* dominator_analysis =
698 context->GetDominatorAnalysis(function_enclosing_instruction);
699 if (context->IsReachable(*context->get_instr_block(instruction)) &&
700 context->IsReachable(*context->get_instr_block(id)) &&
701 dominator_analysis->Dominates(id_definition, instruction)) {
702 // The id's definition dominates the instruction, and both the definition
703 // and the instruction are in reachable blocks, thus the id is available at
704 // the instruction.
705 return true;
706 }
707 if (id_definition->opcode() == SpvOpVariable &&
708 function_enclosing_instruction ==
709 context->get_instr_block(id)->GetParent()) {
710 assert(!context->IsReachable(*context->get_instr_block(instruction)) &&
711 "If the instruction were in a reachable block we should already "
712 "have returned true.");
713 // The id is a variable and it is in the same function as |instruction|.
714 // This is OK despite |instruction| being unreachable.
715 return true;
716 }
717 return false;
718 }
719
InstructionIsFunctionParameter(opt::Instruction * instruction,opt::Function * function)720 bool InstructionIsFunctionParameter(opt::Instruction* instruction,
721 opt::Function* function) {
722 if (instruction->opcode() != SpvOpFunctionParameter) {
723 return false;
724 }
725 bool found_parameter = false;
726 function->ForEachParam(
727 [instruction, &found_parameter](opt::Instruction* param) {
728 if (param == instruction) {
729 found_parameter = true;
730 }
731 });
732 return found_parameter;
733 }
734
GetTypeId(opt::IRContext * context,uint32_t result_id)735 uint32_t GetTypeId(opt::IRContext* context, uint32_t result_id) {
736 const auto* inst = context->get_def_use_mgr()->GetDef(result_id);
737 assert(inst && "|result_id| is invalid");
738 return inst->type_id();
739 }
740
GetPointeeTypeIdFromPointerType(opt::Instruction * pointer_type_inst)741 uint32_t GetPointeeTypeIdFromPointerType(opt::Instruction* pointer_type_inst) {
742 assert(pointer_type_inst && pointer_type_inst->opcode() == SpvOpTypePointer &&
743 "Precondition: |pointer_type_inst| must be OpTypePointer.");
744 return pointer_type_inst->GetSingleWordInOperand(1);
745 }
746
GetPointeeTypeIdFromPointerType(opt::IRContext * context,uint32_t pointer_type_id)747 uint32_t GetPointeeTypeIdFromPointerType(opt::IRContext* context,
748 uint32_t pointer_type_id) {
749 return GetPointeeTypeIdFromPointerType(
750 context->get_def_use_mgr()->GetDef(pointer_type_id));
751 }
752
GetStorageClassFromPointerType(opt::Instruction * pointer_type_inst)753 SpvStorageClass GetStorageClassFromPointerType(
754 opt::Instruction* pointer_type_inst) {
755 assert(pointer_type_inst && pointer_type_inst->opcode() == SpvOpTypePointer &&
756 "Precondition: |pointer_type_inst| must be OpTypePointer.");
757 return static_cast<SpvStorageClass>(
758 pointer_type_inst->GetSingleWordInOperand(0));
759 }
760
GetStorageClassFromPointerType(opt::IRContext * context,uint32_t pointer_type_id)761 SpvStorageClass GetStorageClassFromPointerType(opt::IRContext* context,
762 uint32_t pointer_type_id) {
763 return GetStorageClassFromPointerType(
764 context->get_def_use_mgr()->GetDef(pointer_type_id));
765 }
766
MaybeGetPointerType(opt::IRContext * context,uint32_t pointee_type_id,SpvStorageClass storage_class)767 uint32_t MaybeGetPointerType(opt::IRContext* context, uint32_t pointee_type_id,
768 SpvStorageClass storage_class) {
769 for (auto& inst : context->types_values()) {
770 switch (inst.opcode()) {
771 case SpvOpTypePointer:
772 if (inst.GetSingleWordInOperand(0) == storage_class &&
773 inst.GetSingleWordInOperand(1) == pointee_type_id) {
774 return inst.result_id();
775 }
776 break;
777 default:
778 break;
779 }
780 }
781 return 0;
782 }
783
InOperandIndexFromOperandIndex(const opt::Instruction & inst,uint32_t absolute_index)784 uint32_t InOperandIndexFromOperandIndex(const opt::Instruction& inst,
785 uint32_t absolute_index) {
786 // Subtract the number of non-input operands from the index
787 return absolute_index - inst.NumOperands() + inst.NumInOperands();
788 }
789
IsNullConstantSupported(const opt::analysis::Type & type)790 bool IsNullConstantSupported(const opt::analysis::Type& type) {
791 return type.AsBool() || type.AsInteger() || type.AsFloat() ||
792 type.AsMatrix() || type.AsVector() || type.AsArray() ||
793 type.AsStruct() || type.AsPointer() || type.AsEvent() ||
794 type.AsDeviceEvent() || type.AsReserveId() || type.AsQueue();
795 }
796
GlobalVariablesMustBeDeclaredInEntryPointInterfaces(const opt::IRContext * ir_context)797 bool GlobalVariablesMustBeDeclaredInEntryPointInterfaces(
798 const opt::IRContext* ir_context) {
799 // TODO(afd): We capture the environments for which this requirement holds.
800 // The check should be refined on demand for other target environments.
801 switch (ir_context->grammar().target_env()) {
802 case SPV_ENV_UNIVERSAL_1_0:
803 case SPV_ENV_UNIVERSAL_1_1:
804 case SPV_ENV_UNIVERSAL_1_2:
805 case SPV_ENV_UNIVERSAL_1_3:
806 case SPV_ENV_VULKAN_1_0:
807 case SPV_ENV_VULKAN_1_1:
808 return false;
809 default:
810 return true;
811 }
812 }
813
AddVariableIdToEntryPointInterfaces(opt::IRContext * context,uint32_t id)814 void AddVariableIdToEntryPointInterfaces(opt::IRContext* context, uint32_t id) {
815 if (GlobalVariablesMustBeDeclaredInEntryPointInterfaces(context)) {
816 // Conservatively add this global to the interface of every entry point in
817 // the module. This means that the global is available for other
818 // transformations to use.
819 //
820 // A downside of this is that the global will be in the interface even if it
821 // ends up never being used.
822 //
823 // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3111) revisit
824 // this if a more thorough approach to entry point interfaces is taken.
825 for (auto& entry_point : context->module()->entry_points()) {
826 entry_point.AddOperand({SPV_OPERAND_TYPE_ID, {id}});
827 }
828 }
829 }
830
AddGlobalVariable(opt::IRContext * context,uint32_t result_id,uint32_t type_id,SpvStorageClass storage_class,uint32_t initializer_id)831 opt::Instruction* AddGlobalVariable(opt::IRContext* context, uint32_t result_id,
832 uint32_t type_id,
833 SpvStorageClass storage_class,
834 uint32_t initializer_id) {
835 // Check various preconditions.
836 assert(result_id != 0 && "Result id can't be 0");
837
838 assert((storage_class == SpvStorageClassPrivate ||
839 storage_class == SpvStorageClassWorkgroup) &&
840 "Variable's storage class must be either Private or Workgroup");
841
842 auto* type_inst = context->get_def_use_mgr()->GetDef(type_id);
843 (void)type_inst; // Variable becomes unused in release mode.
844 assert(type_inst && type_inst->opcode() == SpvOpTypePointer &&
845 GetStorageClassFromPointerType(type_inst) == storage_class &&
846 "Variable's type is invalid");
847
848 if (storage_class == SpvStorageClassWorkgroup) {
849 assert(initializer_id == 0);
850 }
851
852 if (initializer_id != 0) {
853 const auto* constant_inst =
854 context->get_def_use_mgr()->GetDef(initializer_id);
855 (void)constant_inst; // Variable becomes unused in release mode.
856 assert(constant_inst && spvOpcodeIsConstant(constant_inst->opcode()) &&
857 GetPointeeTypeIdFromPointerType(type_inst) ==
858 constant_inst->type_id() &&
859 "Initializer is invalid");
860 }
861
862 opt::Instruction::OperandList operands = {
863 {SPV_OPERAND_TYPE_STORAGE_CLASS, {static_cast<uint32_t>(storage_class)}}};
864
865 if (initializer_id) {
866 operands.push_back({SPV_OPERAND_TYPE_ID, {initializer_id}});
867 }
868
869 auto new_instruction = MakeUnique<opt::Instruction>(
870 context, SpvOpVariable, type_id, result_id, std::move(operands));
871 auto result = new_instruction.get();
872 context->module()->AddGlobalValue(std::move(new_instruction));
873
874 AddVariableIdToEntryPointInterfaces(context, result_id);
875 UpdateModuleIdBound(context, result_id);
876
877 return result;
878 }
879
AddLocalVariable(opt::IRContext * context,uint32_t result_id,uint32_t type_id,uint32_t function_id,uint32_t initializer_id)880 opt::Instruction* AddLocalVariable(opt::IRContext* context, uint32_t result_id,
881 uint32_t type_id, uint32_t function_id,
882 uint32_t initializer_id) {
883 // Check various preconditions.
884 assert(result_id != 0 && "Result id can't be 0");
885
886 auto* type_inst = context->get_def_use_mgr()->GetDef(type_id);
887 (void)type_inst; // Variable becomes unused in release mode.
888 assert(type_inst && type_inst->opcode() == SpvOpTypePointer &&
889 GetStorageClassFromPointerType(type_inst) == SpvStorageClassFunction &&
890 "Variable's type is invalid");
891
892 const auto* constant_inst =
893 context->get_def_use_mgr()->GetDef(initializer_id);
894 (void)constant_inst; // Variable becomes unused in release mode.
895 assert(constant_inst && spvOpcodeIsConstant(constant_inst->opcode()) &&
896 GetPointeeTypeIdFromPointerType(type_inst) ==
897 constant_inst->type_id() &&
898 "Initializer is invalid");
899
900 auto* function = FindFunction(context, function_id);
901 assert(function && "Function id is invalid");
902
903 auto new_instruction = MakeUnique<opt::Instruction>(
904 context, SpvOpVariable, type_id, result_id,
905 opt::Instruction::OperandList{
906 {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
907 {SPV_OPERAND_TYPE_ID, {initializer_id}}});
908 auto result = new_instruction.get();
909 function->begin()->begin()->InsertBefore(std::move(new_instruction));
910
911 UpdateModuleIdBound(context, result_id);
912
913 return result;
914 }
915
// Returns true if some value occurs more than once in |arr|.
bool HasDuplicates(const std::vector<uint32_t>& arr) {
  std::unordered_set<uint32_t> seen;
  for (uint32_t value : arr) {
    // insert() reports false in .second when |value| was already present.
    if (!seen.insert(value).second) {
      return true;
    }
  }
  return false;
}
920
IsPermutationOfRange(const std::vector<uint32_t> & arr,uint32_t lo,uint32_t hi)921 bool IsPermutationOfRange(const std::vector<uint32_t>& arr, uint32_t lo,
922 uint32_t hi) {
923 if (arr.empty()) {
924 return lo > hi;
925 }
926
927 if (HasDuplicates(arr)) {
928 return false;
929 }
930
931 auto min_max = std::minmax_element(arr.begin(), arr.end());
932 return arr.size() == hi - lo + 1 && *min_max.first == lo &&
933 *min_max.second == hi;
934 }
935
GetParameters(opt::IRContext * ir_context,uint32_t function_id)936 std::vector<opt::Instruction*> GetParameters(opt::IRContext* ir_context,
937 uint32_t function_id) {
938 auto* function = FindFunction(ir_context, function_id);
939 assert(function && "|function_id| is invalid");
940
941 std::vector<opt::Instruction*> result;
942 function->ForEachParam(
943 [&result](opt::Instruction* inst) { result.push_back(inst); });
944
945 return result;
946 }
947
RemoveParameter(opt::IRContext * ir_context,uint32_t parameter_id)948 void RemoveParameter(opt::IRContext* ir_context, uint32_t parameter_id) {
949 auto* function = GetFunctionFromParameterId(ir_context, parameter_id);
950 assert(function && "|parameter_id| is invalid");
951 assert(!FunctionIsEntryPoint(ir_context, function->result_id()) &&
952 "Can't remove parameter from an entry point function");
953
954 function->RemoveParameter(parameter_id);
955
956 // We've just removed parameters from the function and cleared their memory.
957 // Make sure analyses have no dangling pointers.
958 ir_context->InvalidateAnalysesExceptFor(
959 opt::IRContext::Analysis::kAnalysisNone);
960 }
961
GetCallers(opt::IRContext * ir_context,uint32_t function_id)962 std::vector<opt::Instruction*> GetCallers(opt::IRContext* ir_context,
963 uint32_t function_id) {
964 assert(FindFunction(ir_context, function_id) &&
965 "|function_id| is not a result id of a function");
966
967 std::vector<opt::Instruction*> result;
968 ir_context->get_def_use_mgr()->ForEachUser(
969 function_id, [&result, function_id](opt::Instruction* inst) {
970 if (inst->opcode() == SpvOpFunctionCall &&
971 inst->GetSingleWordInOperand(0) == function_id) {
972 result.push_back(inst);
973 }
974 });
975
976 return result;
977 }
978
GetFunctionFromParameterId(opt::IRContext * ir_context,uint32_t param_id)979 opt::Function* GetFunctionFromParameterId(opt::IRContext* ir_context,
980 uint32_t param_id) {
981 auto* param_inst = ir_context->get_def_use_mgr()->GetDef(param_id);
982 assert(param_inst && "Parameter id is invalid");
983
984 for (auto& function : *ir_context->module()) {
985 if (InstructionIsFunctionParameter(param_inst, &function)) {
986 return &function;
987 }
988 }
989
990 return nullptr;
991 }
992
UpdateFunctionType(opt::IRContext * ir_context,uint32_t function_id,uint32_t new_function_type_result_id,uint32_t return_type_id,const std::vector<uint32_t> & parameter_type_ids)993 uint32_t UpdateFunctionType(opt::IRContext* ir_context, uint32_t function_id,
994 uint32_t new_function_type_result_id,
995 uint32_t return_type_id,
996 const std::vector<uint32_t>& parameter_type_ids) {
997 // Check some initial constraints.
998 assert(ir_context->get_type_mgr()->GetType(return_type_id) &&
999 "Return type is invalid");
1000 for (auto id : parameter_type_ids) {
1001 const auto* type = ir_context->get_type_mgr()->GetType(id);
1002 (void)type; // Make compilers happy in release mode.
1003 // Parameters can't be OpTypeVoid.
1004 assert(type && !type->AsVoid() && "Parameter has invalid type");
1005 }
1006
1007 auto* function = FindFunction(ir_context, function_id);
1008 assert(function && "|function_id| is invalid");
1009
1010 auto* old_function_type = GetFunctionType(ir_context, function);
1011 assert(old_function_type && "Function has invalid type");
1012
1013 std::vector<uint32_t> operand_ids = {return_type_id};
1014 operand_ids.insert(operand_ids.end(), parameter_type_ids.begin(),
1015 parameter_type_ids.end());
1016
1017 // A trivial case - we change nothing.
1018 if (FindFunctionType(ir_context, operand_ids) ==
1019 old_function_type->result_id()) {
1020 return old_function_type->result_id();
1021 }
1022
1023 if (ir_context->get_def_use_mgr()->NumUsers(old_function_type) == 1 &&
1024 FindFunctionType(ir_context, operand_ids) == 0) {
1025 // We can change |old_function_type| only if it's used once in the module
1026 // and we are certain we won't create a duplicate as a result of the change.
1027
1028 // Update |old_function_type| in-place.
1029 opt::Instruction::OperandList operands;
1030 for (auto id : operand_ids) {
1031 operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
1032 }
1033
1034 old_function_type->SetInOperands(std::move(operands));
1035
1036 // |operands| may depend on result ids defined below the |old_function_type|
1037 // in the module.
1038 old_function_type->RemoveFromList();
1039 ir_context->AddType(std::unique_ptr<opt::Instruction>(old_function_type));
1040 return old_function_type->result_id();
1041 } else {
1042 // We can't modify the |old_function_type| so we have to either use an
1043 // existing one or create a new one.
1044 auto type_id = FindOrCreateFunctionType(
1045 ir_context, new_function_type_result_id, operand_ids);
1046 assert(type_id != old_function_type->result_id() &&
1047 "We should've handled this case above");
1048
1049 function->DefInst().SetInOperand(1, {type_id});
1050
1051 // DefUseManager hasn't been updated yet, so if the following condition is
1052 // true, then |old_function_type| will have no users when this function
1053 // returns. We might as well remove it.
1054 if (ir_context->get_def_use_mgr()->NumUsers(old_function_type) == 1) {
1055 ir_context->KillInst(old_function_type);
1056 }
1057
1058 return type_id;
1059 }
1060 }
1061
AddFunctionType(opt::IRContext * ir_context,uint32_t result_id,const std::vector<uint32_t> & type_ids)1062 void AddFunctionType(opt::IRContext* ir_context, uint32_t result_id,
1063 const std::vector<uint32_t>& type_ids) {
1064 assert(result_id != 0 && "Result id can't be 0");
1065 assert(!type_ids.empty() &&
1066 "OpTypeFunction always has at least one operand - function's return "
1067 "type");
1068 assert(IsNonFunctionTypeId(ir_context, type_ids[0]) &&
1069 "Return type must not be a function");
1070
1071 for (size_t i = 1; i < type_ids.size(); ++i) {
1072 const auto* param_type = ir_context->get_type_mgr()->GetType(type_ids[i]);
1073 (void)param_type; // Make compiler happy in release mode.
1074 assert(param_type && !param_type->AsVoid() && !param_type->AsFunction() &&
1075 "Function parameter can't have a function or void type");
1076 }
1077
1078 opt::Instruction::OperandList operands;
1079 operands.reserve(type_ids.size());
1080 for (auto id : type_ids) {
1081 operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
1082 }
1083
1084 ir_context->AddType(MakeUnique<opt::Instruction>(
1085 ir_context, SpvOpTypeFunction, 0, result_id, std::move(operands)));
1086
1087 UpdateModuleIdBound(ir_context, result_id);
1088 }
1089
FindOrCreateFunctionType(opt::IRContext * ir_context,uint32_t result_id,const std::vector<uint32_t> & type_ids)1090 uint32_t FindOrCreateFunctionType(opt::IRContext* ir_context,
1091 uint32_t result_id,
1092 const std::vector<uint32_t>& type_ids) {
1093 if (auto existing_id = FindFunctionType(ir_context, type_ids)) {
1094 return existing_id;
1095 }
1096 AddFunctionType(ir_context, result_id, type_ids);
1097 return result_id;
1098 }
1099
MaybeGetIntegerType(opt::IRContext * ir_context,uint32_t width,bool is_signed)1100 uint32_t MaybeGetIntegerType(opt::IRContext* ir_context, uint32_t width,
1101 bool is_signed) {
1102 opt::analysis::Integer type(width, is_signed);
1103 return ir_context->get_type_mgr()->GetId(&type);
1104 }
1105
MaybeGetFloatType(opt::IRContext * ir_context,uint32_t width)1106 uint32_t MaybeGetFloatType(opt::IRContext* ir_context, uint32_t width) {
1107 opt::analysis::Float type(width);
1108 return ir_context->get_type_mgr()->GetId(&type);
1109 }
1110
MaybeGetBoolType(opt::IRContext * ir_context)1111 uint32_t MaybeGetBoolType(opt::IRContext* ir_context) {
1112 opt::analysis::Bool type;
1113 return ir_context->get_type_mgr()->GetId(&type);
1114 }
1115
MaybeGetVectorType(opt::IRContext * ir_context,uint32_t component_type_id,uint32_t element_count)1116 uint32_t MaybeGetVectorType(opt::IRContext* ir_context,
1117 uint32_t component_type_id,
1118 uint32_t element_count) {
1119 const auto* component_type =
1120 ir_context->get_type_mgr()->GetType(component_type_id);
1121 assert(component_type &&
1122 (component_type->AsInteger() || component_type->AsFloat() ||
1123 component_type->AsBool()) &&
1124 "|component_type_id| is invalid");
1125 assert(element_count >= 2 && element_count <= 4 &&
1126 "Precondition: component count must be in range [2, 4].");
1127 opt::analysis::Vector type(component_type, element_count);
1128 return ir_context->get_type_mgr()->GetId(&type);
1129 }
1130
MaybeGetStructType(opt::IRContext * ir_context,const std::vector<uint32_t> & component_type_ids)1131 uint32_t MaybeGetStructType(opt::IRContext* ir_context,
1132 const std::vector<uint32_t>& component_type_ids) {
1133 for (auto& type_or_value : ir_context->types_values()) {
1134 if (type_or_value.opcode() != SpvOpTypeStruct ||
1135 type_or_value.NumInOperands() !=
1136 static_cast<uint32_t>(component_type_ids.size())) {
1137 continue;
1138 }
1139 bool all_components_match = true;
1140 for (uint32_t i = 0; i < component_type_ids.size(); i++) {
1141 if (type_or_value.GetSingleWordInOperand(i) != component_type_ids[i]) {
1142 all_components_match = false;
1143 break;
1144 }
1145 }
1146 if (all_components_match) {
1147 return type_or_value.result_id();
1148 }
1149 }
1150 return 0;
1151 }
1152
MaybeGetVoidType(opt::IRContext * ir_context)1153 uint32_t MaybeGetVoidType(opt::IRContext* ir_context) {
1154 opt::analysis::Void type;
1155 return ir_context->get_type_mgr()->GetId(&type);
1156 }
1157
MaybeGetZeroConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,uint32_t scalar_or_composite_type_id,bool is_irrelevant)1158 uint32_t MaybeGetZeroConstant(
1159 opt::IRContext* ir_context,
1160 const TransformationContext& transformation_context,
1161 uint32_t scalar_or_composite_type_id, bool is_irrelevant) {
1162 const auto* type_inst =
1163 ir_context->get_def_use_mgr()->GetDef(scalar_or_composite_type_id);
1164 assert(type_inst && "|scalar_or_composite_type_id| is invalid");
1165
1166 switch (type_inst->opcode()) {
1167 case SpvOpTypeBool:
1168 return MaybeGetBoolConstant(ir_context, transformation_context, false,
1169 is_irrelevant);
1170 case SpvOpTypeFloat:
1171 case SpvOpTypeInt: {
1172 const auto width = type_inst->GetSingleWordInOperand(0);
1173 std::vector<uint32_t> words = {0};
1174 if (width > 32) {
1175 words.push_back(0);
1176 }
1177
1178 return MaybeGetScalarConstant(ir_context, transformation_context, words,
1179 scalar_or_composite_type_id, is_irrelevant);
1180 }
1181 case SpvOpTypeStruct: {
1182 std::vector<uint32_t> component_ids;
1183 for (uint32_t i = 0; i < type_inst->NumInOperands(); ++i) {
1184 const auto component_type_id = type_inst->GetSingleWordInOperand(i);
1185
1186 auto component_id =
1187 MaybeGetZeroConstant(ir_context, transformation_context,
1188 component_type_id, is_irrelevant);
1189
1190 if (component_id == 0 && is_irrelevant) {
1191 // Irrelevant constants can use either relevant or irrelevant
1192 // constituents.
1193 component_id = MaybeGetZeroConstant(
1194 ir_context, transformation_context, component_type_id, false);
1195 }
1196
1197 if (component_id == 0) {
1198 return 0;
1199 }
1200
1201 component_ids.push_back(component_id);
1202 }
1203
1204 return MaybeGetCompositeConstant(
1205 ir_context, transformation_context, component_ids,
1206 scalar_or_composite_type_id, is_irrelevant);
1207 }
1208 case SpvOpTypeMatrix:
1209 case SpvOpTypeVector: {
1210 const auto component_type_id = type_inst->GetSingleWordInOperand(0);
1211
1212 auto component_id = MaybeGetZeroConstant(
1213 ir_context, transformation_context, component_type_id, is_irrelevant);
1214
1215 if (component_id == 0 && is_irrelevant) {
1216 // Irrelevant constants can use either relevant or irrelevant
1217 // constituents.
1218 component_id = MaybeGetZeroConstant(ir_context, transformation_context,
1219 component_type_id, false);
1220 }
1221
1222 if (component_id == 0) {
1223 return 0;
1224 }
1225
1226 const auto component_count = type_inst->GetSingleWordInOperand(1);
1227 return MaybeGetCompositeConstant(
1228 ir_context, transformation_context,
1229 std::vector<uint32_t>(component_count, component_id),
1230 scalar_or_composite_type_id, is_irrelevant);
1231 }
1232 case SpvOpTypeArray: {
1233 const auto component_type_id = type_inst->GetSingleWordInOperand(0);
1234
1235 auto component_id = MaybeGetZeroConstant(
1236 ir_context, transformation_context, component_type_id, is_irrelevant);
1237
1238 if (component_id == 0 && is_irrelevant) {
1239 // Irrelevant constants can use either relevant or irrelevant
1240 // constituents.
1241 component_id = MaybeGetZeroConstant(ir_context, transformation_context,
1242 component_type_id, false);
1243 }
1244
1245 if (component_id == 0) {
1246 return 0;
1247 }
1248
1249 return MaybeGetCompositeConstant(
1250 ir_context, transformation_context,
1251 std::vector<uint32_t>(GetArraySize(*type_inst, ir_context),
1252 component_id),
1253 scalar_or_composite_type_id, is_irrelevant);
1254 }
1255 default:
1256 assert(false && "Type is not supported");
1257 return 0;
1258 }
1259 }
1260
CanCreateConstant(opt::IRContext * ir_context,uint32_t type_id)1261 bool CanCreateConstant(opt::IRContext* ir_context, uint32_t type_id) {
1262 opt::Instruction* type_instr = ir_context->get_def_use_mgr()->GetDef(type_id);
1263 assert(type_instr != nullptr && "The type must exist.");
1264 assert(spvOpcodeGeneratesType(type_instr->opcode()) &&
1265 "A type-generating opcode was expected.");
1266 switch (type_instr->opcode()) {
1267 case SpvOpTypeBool:
1268 case SpvOpTypeInt:
1269 case SpvOpTypeFloat:
1270 case SpvOpTypeMatrix:
1271 case SpvOpTypeVector:
1272 return true;
1273 case SpvOpTypeArray:
1274 return CanCreateConstant(ir_context,
1275 type_instr->GetSingleWordInOperand(0));
1276 case SpvOpTypeStruct:
1277 if (HasBlockOrBufferBlockDecoration(ir_context, type_id)) {
1278 return false;
1279 }
1280 for (uint32_t index = 0; index < type_instr->NumInOperands(); index++) {
1281 if (!CanCreateConstant(ir_context,
1282 type_instr->GetSingleWordInOperand(index))) {
1283 return false;
1284 }
1285 }
1286 return true;
1287 default:
1288 return false;
1289 }
1290 }
1291
MaybeGetScalarConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t scalar_type_id,bool is_irrelevant)1292 uint32_t MaybeGetScalarConstant(
1293 opt::IRContext* ir_context,
1294 const TransformationContext& transformation_context,
1295 const std::vector<uint32_t>& words, uint32_t scalar_type_id,
1296 bool is_irrelevant) {
1297 const auto* type = ir_context->get_type_mgr()->GetType(scalar_type_id);
1298 assert(type && "|scalar_type_id| is invalid");
1299
1300 if (const auto* int_type = type->AsInteger()) {
1301 return MaybeGetIntegerConstant(ir_context, transformation_context, words,
1302 int_type->width(), int_type->IsSigned(),
1303 is_irrelevant);
1304 } else if (const auto* float_type = type->AsFloat()) {
1305 return MaybeGetFloatConstant(ir_context, transformation_context, words,
1306 float_type->width(), is_irrelevant);
1307 } else {
1308 assert(type->AsBool() && words.size() == 1 &&
1309 "|scalar_type_id| doesn't represent a scalar type");
1310 return MaybeGetBoolConstant(ir_context, transformation_context, words[0],
1311 is_irrelevant);
1312 }
1313 }
1314
MaybeGetCompositeConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & component_ids,uint32_t composite_type_id,bool is_irrelevant)1315 uint32_t MaybeGetCompositeConstant(
1316 opt::IRContext* ir_context,
1317 const TransformationContext& transformation_context,
1318 const std::vector<uint32_t>& component_ids, uint32_t composite_type_id,
1319 bool is_irrelevant) {
1320 const auto* type = ir_context->get_type_mgr()->GetType(composite_type_id);
1321 (void)type; // Make compilers happy in release mode.
1322 assert(IsCompositeType(type) && "|composite_type_id| is invalid");
1323
1324 for (const auto& inst : ir_context->types_values()) {
1325 if (inst.opcode() == SpvOpConstantComposite &&
1326 inst.type_id() == composite_type_id &&
1327 transformation_context.GetFactManager()->IdIsIrrelevant(
1328 inst.result_id()) == is_irrelevant &&
1329 inst.NumInOperands() == component_ids.size()) {
1330 bool is_match = true;
1331
1332 for (uint32_t i = 0; i < inst.NumInOperands(); ++i) {
1333 if (inst.GetSingleWordInOperand(i) != component_ids[i]) {
1334 is_match = false;
1335 break;
1336 }
1337 }
1338
1339 if (is_match) {
1340 return inst.result_id();
1341 }
1342 }
1343 }
1344
1345 return 0;
1346 }
1347
MaybeGetIntegerConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t width,bool is_signed,bool is_irrelevant)1348 uint32_t MaybeGetIntegerConstant(
1349 opt::IRContext* ir_context,
1350 const TransformationContext& transformation_context,
1351 const std::vector<uint32_t>& words, uint32_t width, bool is_signed,
1352 bool is_irrelevant) {
1353 if (auto type_id = MaybeGetIntegerType(ir_context, width, is_signed)) {
1354 return MaybeGetOpConstant(ir_context, transformation_context, words,
1355 type_id, is_irrelevant);
1356 }
1357
1358 return 0;
1359 }
1360
MaybeGetIntegerConstantFromValueAndType(opt::IRContext * ir_context,uint32_t value,uint32_t int_type_id)1361 uint32_t MaybeGetIntegerConstantFromValueAndType(opt::IRContext* ir_context,
1362 uint32_t value,
1363 uint32_t int_type_id) {
1364 auto int_type_inst = ir_context->get_def_use_mgr()->GetDef(int_type_id);
1365
1366 assert(int_type_inst && "The given type id must exist.");
1367
1368 auto int_type = ir_context->get_type_mgr()
1369 ->GetType(int_type_inst->result_id())
1370 ->AsInteger();
1371
1372 assert(int_type && int_type->width() == 32 &&
1373 "The given type id must correspond to an 32-bit integer type.");
1374
1375 opt::analysis::IntConstant constant(int_type, {value});
1376
1377 // Check that the constant exists in the module.
1378 if (!ir_context->get_constant_mgr()->FindConstant(&constant)) {
1379 return 0;
1380 }
1381
1382 return ir_context->get_constant_mgr()
1383 ->GetDefiningInstruction(&constant)
1384 ->result_id();
1385 }
1386
MaybeGetFloatConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t width,bool is_irrelevant)1387 uint32_t MaybeGetFloatConstant(
1388 opt::IRContext* ir_context,
1389 const TransformationContext& transformation_context,
1390 const std::vector<uint32_t>& words, uint32_t width, bool is_irrelevant) {
1391 if (auto type_id = MaybeGetFloatType(ir_context, width)) {
1392 return MaybeGetOpConstant(ir_context, transformation_context, words,
1393 type_id, is_irrelevant);
1394 }
1395
1396 return 0;
1397 }
1398
MaybeGetBoolConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,bool value,bool is_irrelevant)1399 uint32_t MaybeGetBoolConstant(
1400 opt::IRContext* ir_context,
1401 const TransformationContext& transformation_context, bool value,
1402 bool is_irrelevant) {
1403 if (auto type_id = MaybeGetBoolType(ir_context)) {
1404 for (const auto& inst : ir_context->types_values()) {
1405 if (inst.opcode() == (value ? SpvOpConstantTrue : SpvOpConstantFalse) &&
1406 inst.type_id() == type_id &&
1407 transformation_context.GetFactManager()->IdIsIrrelevant(
1408 inst.result_id()) == is_irrelevant) {
1409 return inst.result_id();
1410 }
1411 }
1412 }
1413
1414 return 0;
1415 }
1416
// Returns the 32-bit literal words encoding the low |width| bits of |value|.
// Those bits are first sign-extended (if |is_signed|) or zero-extended to
// 64 bits. One word is returned when |width| <= 32; otherwise two are
// returned, low-order word first.
//
// |width| must be in [1, 64]. The extension is computed with a mask rather
// than with shifts by (64 - |width|). A shift by 64 (|width| == 0) is
// undefined behaviour, and right-shifting a negative signed value is
// implementation-defined before C++20.
std::vector<uint32_t> IntToWords(uint64_t value, uint32_t width,
                                 bool is_signed) {
  assert(width >= 1 && width <= 64 &&
         "The bit width should be in the range [1, 64]");

  // For width 64 every bit is significant and no extension is needed.
  if (width < 64) {
    // |width| is in [0, 63] here, so this shift is well-defined.
    const uint64_t mask = (uint64_t{1} << width) - 1;
    value &= mask;  // Zero-extend.
    if (is_signed && width > 0 && ((value >> (width - 1)) & 1u) != 0) {
      // The sign bit is set: fill all higher bits with ones.
      value |= ~mask;
    }
  }

  std::vector<uint32_t> result;
  result.push_back(static_cast<uint32_t>(value));
  if (width > 32) {
    result.push_back(static_cast<uint32_t>(value >> 32));
  }
  return result;
}
1440
TypesAreEqualUpToSign(opt::IRContext * ir_context,uint32_t type1_id,uint32_t type2_id)1441 bool TypesAreEqualUpToSign(opt::IRContext* ir_context, uint32_t type1_id,
1442 uint32_t type2_id) {
1443 if (type1_id == type2_id) {
1444 return true;
1445 }
1446
1447 auto type1 = ir_context->get_type_mgr()->GetType(type1_id);
1448 auto type2 = ir_context->get_type_mgr()->GetType(type2_id);
1449
1450 // Integer scalar types must have the same width
1451 if (type1->AsInteger() && type2->AsInteger()) {
1452 return type1->AsInteger()->width() == type2->AsInteger()->width();
1453 }
1454
1455 // Integer vector types must have the same number of components and their
1456 // component types must be integers with the same width.
1457 if (type1->AsVector() && type2->AsVector()) {
1458 auto component_type1 = type1->AsVector()->element_type()->AsInteger();
1459 auto component_type2 = type2->AsVector()->element_type()->AsInteger();
1460
1461 // Only check the component count and width if they are integer.
1462 if (component_type1 && component_type2) {
1463 return type1->AsVector()->element_count() ==
1464 type2->AsVector()->element_count() &&
1465 component_type1->width() == component_type2->width();
1466 }
1467 }
1468
1469 // In all other cases, the types cannot be considered equal.
1470 return false;
1471 }
1472
RepeatedUInt32PairToMap(const google::protobuf::RepeatedPtrField<protobufs::UInt32Pair> & data)1473 std::map<uint32_t, uint32_t> RepeatedUInt32PairToMap(
1474 const google::protobuf::RepeatedPtrField<protobufs::UInt32Pair>& data) {
1475 std::map<uint32_t, uint32_t> result;
1476
1477 for (const auto& entry : data) {
1478 result[entry.first()] = entry.second();
1479 }
1480
1481 return result;
1482 }
1483
1484 google::protobuf::RepeatedPtrField<protobufs::UInt32Pair>
MapToRepeatedUInt32Pair(const std::map<uint32_t,uint32_t> & data)1485 MapToRepeatedUInt32Pair(const std::map<uint32_t, uint32_t>& data) {
1486 google::protobuf::RepeatedPtrField<protobufs::UInt32Pair> result;
1487
1488 for (const auto& entry : data) {
1489 protobufs::UInt32Pair pair;
1490 pair.set_first(entry.first);
1491 pair.set_second(entry.second);
1492 *result.Add() = std::move(pair);
1493 }
1494
1495 return result;
1496 }
1497
GetLastInsertBeforeInstruction(opt::IRContext * ir_context,uint32_t block_id,SpvOp opcode)1498 opt::Instruction* GetLastInsertBeforeInstruction(opt::IRContext* ir_context,
1499 uint32_t block_id,
1500 SpvOp opcode) {
1501 // CFG::block uses std::map::at which throws an exception when |block_id| is
1502 // invalid. The error message is unhelpful, though. Thus, we test that
1503 // |block_id| is valid here.
1504 const auto* label_inst = ir_context->get_def_use_mgr()->GetDef(block_id);
1505 (void)label_inst; // Make compilers happy in release mode.
1506 assert(label_inst && label_inst->opcode() == SpvOpLabel &&
1507 "|block_id| is invalid");
1508
1509 auto* block = ir_context->cfg()->block(block_id);
1510 auto it = block->rbegin();
1511 assert(it != block->rend() && "Basic block can't be empty");
1512
1513 if (block->GetMergeInst()) {
1514 ++it;
1515 assert(it != block->rend() &&
1516 "|block| must have at least two instructions:"
1517 "terminator and a merge instruction");
1518 }
1519
1520 return CanInsertOpcodeBeforeInstruction(opcode, &*it) ? &*it : nullptr;
1521 }
1522
IdUseCanBeReplaced(opt::IRContext * ir_context,const TransformationContext & transformation_context,opt::Instruction * use_instruction,uint32_t use_in_operand_index)1523 bool IdUseCanBeReplaced(opt::IRContext* ir_context,
1524 const TransformationContext& transformation_context,
1525 opt::Instruction* use_instruction,
1526 uint32_t use_in_operand_index) {
1527 if (spvOpcodeIsAccessChain(use_instruction->opcode()) &&
1528 use_in_operand_index > 0) {
1529 // A replacement for an irrelevant index in OpAccessChain must be clamped
1530 // first.
1531 if (transformation_context.GetFactManager()->IdIsIrrelevant(
1532 use_instruction->GetSingleWordInOperand(use_in_operand_index))) {
1533 return false;
1534 }
1535
1536 // This is an access chain index. If the (sub-)object being accessed by the
1537 // given index has struct type then we cannot replace the use, as it needs
1538 // to be an OpConstant.
1539
1540 // Get the top-level composite type that is being accessed.
1541 auto object_being_accessed = ir_context->get_def_use_mgr()->GetDef(
1542 use_instruction->GetSingleWordInOperand(0));
1543 auto pointer_type =
1544 ir_context->get_type_mgr()->GetType(object_being_accessed->type_id());
1545 assert(pointer_type->AsPointer());
1546 auto composite_type_being_accessed =
1547 pointer_type->AsPointer()->pointee_type();
1548
1549 // Now walk the access chain, tracking the type of each sub-object of the
1550 // composite that is traversed, until the index of interest is reached.
1551 for (uint32_t index_in_operand = 1; index_in_operand < use_in_operand_index;
1552 index_in_operand++) {
1553 // For vectors, matrices and arrays, getting the type of the sub-object is
1554 // trivial. For the struct case, the sub-object type is field-sensitive,
1555 // and depends on the constant index that is used.
1556 if (composite_type_being_accessed->AsVector()) {
1557 composite_type_being_accessed =
1558 composite_type_being_accessed->AsVector()->element_type();
1559 } else if (composite_type_being_accessed->AsMatrix()) {
1560 composite_type_being_accessed =
1561 composite_type_being_accessed->AsMatrix()->element_type();
1562 } else if (composite_type_being_accessed->AsArray()) {
1563 composite_type_being_accessed =
1564 composite_type_being_accessed->AsArray()->element_type();
1565 } else if (composite_type_being_accessed->AsRuntimeArray()) {
1566 composite_type_being_accessed =
1567 composite_type_being_accessed->AsRuntimeArray()->element_type();
1568 } else {
1569 assert(composite_type_being_accessed->AsStruct());
1570 auto constant_index_instruction = ir_context->get_def_use_mgr()->GetDef(
1571 use_instruction->GetSingleWordInOperand(index_in_operand));
1572 assert(constant_index_instruction->opcode() == SpvOpConstant);
1573 uint32_t member_index =
1574 constant_index_instruction->GetSingleWordInOperand(0);
1575 composite_type_being_accessed =
1576 composite_type_being_accessed->AsStruct()
1577 ->element_types()[member_index];
1578 }
1579 }
1580
1581 // We have found the composite type being accessed by the index we are
1582 // considering replacing. If it is a struct, then we cannot do the
1583 // replacement as struct indices must be constants.
1584 if (composite_type_being_accessed->AsStruct()) {
1585 return false;
1586 }
1587 }
1588
1589 if (use_instruction->opcode() == SpvOpFunctionCall &&
1590 use_in_operand_index > 0) {
1591 // This is a function call argument. It is not allowed to have pointer
1592 // type.
1593
1594 // Get the definition of the function being called.
1595 auto function = ir_context->get_def_use_mgr()->GetDef(
1596 use_instruction->GetSingleWordInOperand(0));
1597 // From the function definition, get the function type.
1598 auto function_type = ir_context->get_def_use_mgr()->GetDef(
1599 function->GetSingleWordInOperand(1));
1600 // OpTypeFunction's 0-th input operand is the function return type, and the
1601 // function argument types follow. Because the arguments to OpFunctionCall
1602 // start from input operand 1, we can use |use_in_operand_index| to get the
1603 // type associated with this function argument.
1604 auto parameter_type = ir_context->get_type_mgr()->GetType(
1605 function_type->GetSingleWordInOperand(use_in_operand_index));
1606 if (parameter_type->AsPointer()) {
1607 return false;
1608 }
1609 }
1610
1611 if (use_instruction->opcode() == SpvOpImageTexelPointer &&
1612 use_in_operand_index == 2) {
1613 // The OpImageTexelPointer instruction has a Sample parameter that in some
1614 // situations must be an id for the value 0. To guard against disrupting
1615 // that requirement, we do not replace this argument to that instruction.
1616 return false;
1617 }
1618
1619 if (ir_context->get_feature_mgr()->HasCapability(SpvCapabilityShader)) {
1620 // With the Shader capability, memory scope and memory semantics operands
1621 // are required to be constants, so they cannot be replaced arbitrarily.
1622 switch (use_instruction->opcode()) {
1623 case SpvOpAtomicLoad:
1624 case SpvOpAtomicStore:
1625 case SpvOpAtomicExchange:
1626 case SpvOpAtomicIIncrement:
1627 case SpvOpAtomicIDecrement:
1628 case SpvOpAtomicIAdd:
1629 case SpvOpAtomicISub:
1630 case SpvOpAtomicSMin:
1631 case SpvOpAtomicUMin:
1632 case SpvOpAtomicSMax:
1633 case SpvOpAtomicUMax:
1634 case SpvOpAtomicAnd:
1635 case SpvOpAtomicOr:
1636 case SpvOpAtomicXor:
1637 if (use_in_operand_index == 1 || use_in_operand_index == 2) {
1638 return false;
1639 }
1640 break;
1641 case SpvOpAtomicCompareExchange:
1642 if (use_in_operand_index == 1 || use_in_operand_index == 2 ||
1643 use_in_operand_index == 3) {
1644 return false;
1645 }
1646 break;
1647 case SpvOpAtomicCompareExchangeWeak:
1648 case SpvOpAtomicFlagTestAndSet:
1649 case SpvOpAtomicFlagClear:
1650 case SpvOpAtomicFAddEXT:
1651 assert(false && "Not allowed with the Shader capability.");
1652 default:
1653 break;
1654 }
1655 }
1656
1657 return true;
1658 }
1659
MembersHaveBuiltInDecoration(opt::IRContext * ir_context,uint32_t struct_type_id)1660 bool MembersHaveBuiltInDecoration(opt::IRContext* ir_context,
1661 uint32_t struct_type_id) {
1662 const auto* type_inst = ir_context->get_def_use_mgr()->GetDef(struct_type_id);
1663 assert(type_inst && type_inst->opcode() == SpvOpTypeStruct &&
1664 "|struct_type_id| is not a result id of an OpTypeStruct");
1665
1666 uint32_t builtin_count = 0;
1667 ir_context->get_def_use_mgr()->ForEachUser(
1668 type_inst,
1669 [struct_type_id, &builtin_count](const opt::Instruction* user) {
1670 if (user->opcode() == SpvOpMemberDecorate &&
1671 user->GetSingleWordInOperand(0) == struct_type_id &&
1672 static_cast<SpvDecoration>(user->GetSingleWordInOperand(2)) ==
1673 SpvDecorationBuiltIn) {
1674 ++builtin_count;
1675 }
1676 });
1677
1678 assert((builtin_count == 0 || builtin_count == type_inst->NumInOperands()) &&
1679 "The module is invalid: either none or all of the members of "
1680 "|struct_type_id| may be builtin");
1681
1682 return builtin_count != 0;
1683 }
1684
HasBlockOrBufferBlockDecoration(opt::IRContext * ir_context,uint32_t id)1685 bool HasBlockOrBufferBlockDecoration(opt::IRContext* ir_context, uint32_t id) {
1686 for (auto decoration : {SpvDecorationBlock, SpvDecorationBufferBlock}) {
1687 if (!ir_context->get_decoration_mgr()->WhileEachDecoration(
1688 id, decoration, [](const opt::Instruction & /*unused*/) -> bool {
1689 return false;
1690 })) {
1691 return true;
1692 }
1693 }
1694 return false;
1695 }
1696
SplittingBeforeInstructionSeparatesOpSampledImageDefinitionFromUse(opt::BasicBlock * block_to_split,opt::Instruction * split_before)1697 bool SplittingBeforeInstructionSeparatesOpSampledImageDefinitionFromUse(
1698 opt::BasicBlock* block_to_split, opt::Instruction* split_before) {
1699 std::set<uint32_t> sampled_image_result_ids;
1700 bool before_split = true;
1701
1702 // Check all the instructions in the block to split.
1703 for (auto& instruction : *block_to_split) {
1704 if (&instruction == &*split_before) {
1705 before_split = false;
1706 }
1707 if (before_split) {
1708 // If the instruction comes before the split and its opcode is
1709 // OpSampledImage, record its result id.
1710 if (instruction.opcode() == SpvOpSampledImage) {
1711 sampled_image_result_ids.insert(instruction.result_id());
1712 }
1713 } else {
1714 // If the instruction comes after the split, check if ids
1715 // corresponding to OpSampledImage instructions defined before the split
1716 // are used, and return true if they are.
1717 if (!instruction.WhileEachInId(
1718 [&sampled_image_result_ids](uint32_t* id) -> bool {
1719 return !sampled_image_result_ids.count(*id);
1720 })) {
1721 return true;
1722 }
1723 }
1724 }
1725
1726 // No usage that would be separated from the definition has been found.
1727 return false;
1728 }
1729
InstructionHasNoSideEffects(const opt::Instruction & instruction)1730 bool InstructionHasNoSideEffects(const opt::Instruction& instruction) {
1731 switch (instruction.opcode()) {
1732 case SpvOpUndef:
1733 case SpvOpAccessChain:
1734 case SpvOpInBoundsAccessChain:
1735 case SpvOpArrayLength:
1736 case SpvOpVectorExtractDynamic:
1737 case SpvOpVectorInsertDynamic:
1738 case SpvOpVectorShuffle:
1739 case SpvOpCompositeConstruct:
1740 case SpvOpCompositeExtract:
1741 case SpvOpCompositeInsert:
1742 case SpvOpCopyObject:
1743 case SpvOpTranspose:
1744 case SpvOpConvertFToU:
1745 case SpvOpConvertFToS:
1746 case SpvOpConvertSToF:
1747 case SpvOpConvertUToF:
1748 case SpvOpUConvert:
1749 case SpvOpSConvert:
1750 case SpvOpFConvert:
1751 case SpvOpQuantizeToF16:
1752 case SpvOpSatConvertSToU:
1753 case SpvOpSatConvertUToS:
1754 case SpvOpBitcast:
1755 case SpvOpSNegate:
1756 case SpvOpFNegate:
1757 case SpvOpIAdd:
1758 case SpvOpFAdd:
1759 case SpvOpISub:
1760 case SpvOpFSub:
1761 case SpvOpIMul:
1762 case SpvOpFMul:
1763 case SpvOpUDiv:
1764 case SpvOpSDiv:
1765 case SpvOpFDiv:
1766 case SpvOpUMod:
1767 case SpvOpSRem:
1768 case SpvOpSMod:
1769 case SpvOpFRem:
1770 case SpvOpFMod:
1771 case SpvOpVectorTimesScalar:
1772 case SpvOpMatrixTimesScalar:
1773 case SpvOpVectorTimesMatrix:
1774 case SpvOpMatrixTimesVector:
1775 case SpvOpMatrixTimesMatrix:
1776 case SpvOpOuterProduct:
1777 case SpvOpDot:
1778 case SpvOpIAddCarry:
1779 case SpvOpISubBorrow:
1780 case SpvOpUMulExtended:
1781 case SpvOpSMulExtended:
1782 case SpvOpAny:
1783 case SpvOpAll:
1784 case SpvOpIsNan:
1785 case SpvOpIsInf:
1786 case SpvOpIsFinite:
1787 case SpvOpIsNormal:
1788 case SpvOpSignBitSet:
1789 case SpvOpLessOrGreater:
1790 case SpvOpOrdered:
1791 case SpvOpUnordered:
1792 case SpvOpLogicalEqual:
1793 case SpvOpLogicalNotEqual:
1794 case SpvOpLogicalOr:
1795 case SpvOpLogicalAnd:
1796 case SpvOpLogicalNot:
1797 case SpvOpSelect:
1798 case SpvOpIEqual:
1799 case SpvOpINotEqual:
1800 case SpvOpUGreaterThan:
1801 case SpvOpSGreaterThan:
1802 case SpvOpUGreaterThanEqual:
1803 case SpvOpSGreaterThanEqual:
1804 case SpvOpULessThan:
1805 case SpvOpSLessThan:
1806 case SpvOpULessThanEqual:
1807 case SpvOpSLessThanEqual:
1808 case SpvOpFOrdEqual:
1809 case SpvOpFUnordEqual:
1810 case SpvOpFOrdNotEqual:
1811 case SpvOpFUnordNotEqual:
1812 case SpvOpFOrdLessThan:
1813 case SpvOpFUnordLessThan:
1814 case SpvOpFOrdGreaterThan:
1815 case SpvOpFUnordGreaterThan:
1816 case SpvOpFOrdLessThanEqual:
1817 case SpvOpFUnordLessThanEqual:
1818 case SpvOpFOrdGreaterThanEqual:
1819 case SpvOpFUnordGreaterThanEqual:
1820 case SpvOpShiftRightLogical:
1821 case SpvOpShiftRightArithmetic:
1822 case SpvOpShiftLeftLogical:
1823 case SpvOpBitwiseOr:
1824 case SpvOpBitwiseXor:
1825 case SpvOpBitwiseAnd:
1826 case SpvOpNot:
1827 case SpvOpBitFieldInsert:
1828 case SpvOpBitFieldSExtract:
1829 case SpvOpBitFieldUExtract:
1830 case SpvOpBitReverse:
1831 case SpvOpBitCount:
1832 case SpvOpCopyLogical:
1833 case SpvOpPhi:
1834 case SpvOpPtrEqual:
1835 case SpvOpPtrNotEqual:
1836 return true;
1837 default:
1838 return false;
1839 }
1840 }
1841
GetReachableReturnBlocks(opt::IRContext * ir_context,uint32_t function_id)1842 std::set<uint32_t> GetReachableReturnBlocks(opt::IRContext* ir_context,
1843 uint32_t function_id) {
1844 auto function = ir_context->GetFunction(function_id);
1845 assert(function && "The function |function_id| must exist.");
1846
1847 std::set<uint32_t> result;
1848
1849 ir_context->cfg()->ForEachBlockInPostOrder(function->entry().get(),
1850 [&result](opt::BasicBlock* block) {
1851 if (block->IsReturn()) {
1852 result.emplace(block->id());
1853 }
1854 });
1855
1856 return result;
1857 }
1858
NewTerminatorPreservesDominationRules(opt::IRContext * ir_context,uint32_t block_id,opt::Instruction new_terminator)1859 bool NewTerminatorPreservesDominationRules(opt::IRContext* ir_context,
1860 uint32_t block_id,
1861 opt::Instruction new_terminator) {
1862 auto* mutated_block = MaybeFindBlock(ir_context, block_id);
1863 assert(mutated_block && "|block_id| is invalid");
1864
1865 ChangeTerminatorRAII change_terminator_raii(mutated_block,
1866 std::move(new_terminator));
1867 opt::DominatorAnalysis dominator_analysis;
1868 dominator_analysis.InitializeTree(*ir_context->cfg(),
1869 mutated_block->GetParent());
1870
1871 // Check that each dominator appears before each dominated block.
1872 std::unordered_map<uint32_t, size_t> positions;
1873 for (const auto& block : *mutated_block->GetParent()) {
1874 positions[block.id()] = positions.size();
1875 }
1876
1877 std::queue<uint32_t> q({mutated_block->GetParent()->begin()->id()});
1878 std::unordered_set<uint32_t> visited;
1879 while (!q.empty()) {
1880 auto block = q.front();
1881 q.pop();
1882 visited.insert(block);
1883
1884 auto success = ir_context->cfg()->block(block)->WhileEachSuccessorLabel(
1885 [&positions, &visited, &dominator_analysis, block, &q](uint32_t id) {
1886 if (id == block) {
1887 // Handle the case when loop header and continue target are the same
1888 // block.
1889 return true;
1890 }
1891
1892 if (dominator_analysis.Dominates(block, id) &&
1893 positions[block] > positions[id]) {
1894 // |block| dominates |id| but appears after |id| - violates
1895 // domination rules.
1896 return false;
1897 }
1898
1899 if (!visited.count(id)) {
1900 q.push(id);
1901 }
1902
1903 return true;
1904 });
1905
1906 if (!success) {
1907 return false;
1908 }
1909 }
1910
1911 // For each instruction in the |block->GetParent()| function check whether
1912 // all its dependencies satisfy domination rules (i.e. all id operands
1913 // dominate that instruction).
1914 for (const auto& block : *mutated_block->GetParent()) {
1915 if (!ir_context->IsReachable(block)) {
1916 // If some block is not reachable then we don't need to worry about the
1917 // preservation of domination rules for its instructions.
1918 continue;
1919 }
1920
1921 for (const auto& inst : block) {
1922 for (uint32_t i = 0; i < inst.NumInOperands();
1923 i += inst.opcode() == SpvOpPhi ? 2 : 1) {
1924 const auto& operand = inst.GetInOperand(i);
1925 if (!spvIsInIdType(operand.type)) {
1926 continue;
1927 }
1928
1929 if (MaybeFindBlock(ir_context, operand.words[0])) {
1930 // Ignore operands that refer to OpLabel instructions.
1931 continue;
1932 }
1933
1934 const auto* dependency_block =
1935 ir_context->get_instr_block(operand.words[0]);
1936 if (!dependency_block) {
1937 // A global instruction always dominates all instructions in any
1938 // function.
1939 continue;
1940 }
1941
1942 auto domination_target_id = inst.opcode() == SpvOpPhi
1943 ? inst.GetSingleWordInOperand(i + 1)
1944 : block.id();
1945
1946 if (!dominator_analysis.Dominates(dependency_block->id(),
1947 domination_target_id)) {
1948 return false;
1949 }
1950 }
1951 }
1952 }
1953
1954 return true;
1955 }
1956
GetFunctionIterator(opt::IRContext * ir_context,uint32_t function_id)1957 opt::Module::iterator GetFunctionIterator(opt::IRContext* ir_context,
1958 uint32_t function_id) {
1959 return std::find_if(ir_context->module()->begin(),
1960 ir_context->module()->end(),
1961 [function_id](const opt::Function& f) {
1962 return f.result_id() == function_id;
1963 });
1964 }
1965
1966 } // namespace fuzzerutil
1967 } // namespace fuzz
1968 } // namespace spvtools
1969