1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/fuzz/fuzzer_util.h"
16
17 #include <algorithm>
18 #include <unordered_set>
19
20 #include "source/opt/build_module.h"
21
22 namespace spvtools {
23 namespace fuzz {
24
25 namespace fuzzerutil {
26 namespace {
27
28 // A utility class that uses RAII to change and restore the terminator
29 // instruction of the |block|.
30 class ChangeTerminatorRAII {
31 public:
ChangeTerminatorRAII(opt::BasicBlock * block,opt::Instruction new_terminator)32 explicit ChangeTerminatorRAII(opt::BasicBlock* block,
33 opt::Instruction new_terminator)
34 : block_(block), old_terminator_(std::move(*block->terminator())) {
35 *block_->terminator() = std::move(new_terminator);
36 }
37
~ChangeTerminatorRAII()38 ~ChangeTerminatorRAII() {
39 *block_->terminator() = std::move(old_terminator_);
40 }
41
42 private:
43 opt::BasicBlock* block_;
44 opt::Instruction old_terminator_;
45 };
46
MaybeGetOpConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t type_id,bool is_irrelevant)47 uint32_t MaybeGetOpConstant(opt::IRContext* ir_context,
48 const TransformationContext& transformation_context,
49 const std::vector<uint32_t>& words,
50 uint32_t type_id, bool is_irrelevant) {
51 for (const auto& inst : ir_context->types_values()) {
52 if (inst.opcode() == SpvOpConstant && inst.type_id() == type_id &&
53 inst.GetInOperand(0).words == words &&
54 transformation_context.GetFactManager()->IdIsIrrelevant(
55 inst.result_id()) == is_irrelevant) {
56 return inst.result_id();
57 }
58 }
59
60 return 0;
61 }
62
63 } // namespace
64
65 const spvtools::MessageConsumer kSilentMessageConsumer =
66 [](spv_message_level_t, const char*, const spv_position_t&,
__anon77d6a2b50202(spv_message_level_t, const char*, const spv_position_t&, const char*) 67 const char*) -> void {};
68
BuildIRContext(spv_target_env target_env,const spvtools::MessageConsumer & message_consumer,const std::vector<uint32_t> & binary_in,spv_validator_options validator_options,std::unique_ptr<spvtools::opt::IRContext> * ir_context)69 bool BuildIRContext(spv_target_env target_env,
70 const spvtools::MessageConsumer& message_consumer,
71 const std::vector<uint32_t>& binary_in,
72 spv_validator_options validator_options,
73 std::unique_ptr<spvtools::opt::IRContext>* ir_context) {
74 SpirvTools tools(target_env);
75 tools.SetMessageConsumer(message_consumer);
76 if (!tools.IsValid()) {
77 message_consumer(SPV_MSG_ERROR, nullptr, {},
78 "Failed to create SPIRV-Tools interface; stopping.");
79 return false;
80 }
81
82 // Initial binary should be valid.
83 if (!tools.Validate(binary_in.data(), binary_in.size(), validator_options)) {
84 message_consumer(SPV_MSG_ERROR, nullptr, {},
85 "Initial binary is invalid; stopping.");
86 return false;
87 }
88
89 // Build the module from the input binary.
90 auto result = BuildModule(target_env, message_consumer, binary_in.data(),
91 binary_in.size());
92 assert(result && "IRContext must be valid");
93 *ir_context = std::move(result);
94 return true;
95 }
96
IsFreshId(opt::IRContext * context,uint32_t id)97 bool IsFreshId(opt::IRContext* context, uint32_t id) {
98 return !context->get_def_use_mgr()->GetDef(id);
99 }
100
UpdateModuleIdBound(opt::IRContext * context,uint32_t id)101 void UpdateModuleIdBound(opt::IRContext* context, uint32_t id) {
102 // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/2541) consider the
103 // case where the maximum id bound is reached.
104 context->module()->SetIdBound(
105 std::max(context->module()->id_bound(), id + 1));
106 }
107
MaybeFindBlock(opt::IRContext * context,uint32_t maybe_block_id)108 opt::BasicBlock* MaybeFindBlock(opt::IRContext* context,
109 uint32_t maybe_block_id) {
110 auto inst = context->get_def_use_mgr()->GetDef(maybe_block_id);
111 if (inst == nullptr) {
112 // No instruction defining this id was found.
113 return nullptr;
114 }
115 if (inst->opcode() != SpvOpLabel) {
116 // The instruction defining the id is not a label, so it cannot be a block
117 // id.
118 return nullptr;
119 }
120 return context->cfg()->block(maybe_block_id);
121 }
122
PhiIdsOkForNewEdge(opt::IRContext * context,opt::BasicBlock * bb_from,opt::BasicBlock * bb_to,const google::protobuf::RepeatedField<google::protobuf::uint32> & phi_ids)123 bool PhiIdsOkForNewEdge(
124 opt::IRContext* context, opt::BasicBlock* bb_from, opt::BasicBlock* bb_to,
125 const google::protobuf::RepeatedField<google::protobuf::uint32>& phi_ids) {
126 if (bb_from->IsSuccessor(bb_to)) {
127 // There is already an edge from |from_block| to |to_block|, so there is
128 // no need to extend OpPhi instructions. Do not allow phi ids to be
129 // present. This might turn out to be too strict; perhaps it would be OK
130 // just to ignore the ids in this case.
131 return phi_ids.empty();
132 }
133 // The edge would add a previously non-existent edge from |from_block| to
134 // |to_block|, so we go through the given phi ids and check that they exactly
135 // match the OpPhi instructions in |to_block|.
136 uint32_t phi_index = 0;
137 // An explicit loop, rather than applying a lambda to each OpPhi in |bb_to|,
138 // makes sense here because we need to increment |phi_index| for each OpPhi
139 // instruction.
140 for (auto& inst : *bb_to) {
141 if (inst.opcode() != SpvOpPhi) {
142 // The OpPhi instructions all occur at the start of the block; if we find
143 // a non-OpPhi then we have seen them all.
144 break;
145 }
146 if (phi_index == static_cast<uint32_t>(phi_ids.size())) {
147 // Not enough phi ids have been provided to account for the OpPhi
148 // instructions.
149 return false;
150 }
151 // Look for an instruction defining the next phi id.
152 opt::Instruction* phi_extension =
153 context->get_def_use_mgr()->GetDef(phi_ids[phi_index]);
154 if (!phi_extension) {
155 // The id given to extend this OpPhi does not exist.
156 return false;
157 }
158 if (phi_extension->type_id() != inst.type_id()) {
159 // The instruction given to extend this OpPhi either does not have a type
160 // or its type does not match that of the OpPhi.
161 return false;
162 }
163
164 if (context->get_instr_block(phi_extension)) {
165 // The instruction defining the phi id has an associated block (i.e., it
166 // is not a global value). Check whether its definition dominates the
167 // exit of |from_block|.
168 auto dominator_analysis =
169 context->GetDominatorAnalysis(bb_from->GetParent());
170 if (!dominator_analysis->Dominates(phi_extension,
171 bb_from->terminator())) {
172 // The given id is no good as its definition does not dominate the exit
173 // of |from_block|
174 return false;
175 }
176 }
177 phi_index++;
178 }
179 // We allow some of the ids provided for extending OpPhi instructions to be
180 // unused. Their presence does no harm, and requiring a perfect match may
181 // make transformations less likely to cleanly apply.
182 return true;
183 }
184
CreateUnreachableEdgeInstruction(opt::IRContext * ir_context,uint32_t bb_from_id,uint32_t bb_to_id,uint32_t bool_id)185 opt::Instruction CreateUnreachableEdgeInstruction(opt::IRContext* ir_context,
186 uint32_t bb_from_id,
187 uint32_t bb_to_id,
188 uint32_t bool_id) {
189 const auto* bb_from = MaybeFindBlock(ir_context, bb_from_id);
190 assert(bb_from && "|bb_from_id| is invalid");
191 assert(MaybeFindBlock(ir_context, bb_to_id) && "|bb_to_id| is invalid");
192 assert(bb_from->terminator()->opcode() == SpvOpBranch &&
193 "Precondition on terminator of bb_from is not satisfied");
194
195 // Get the id of the boolean constant to be used as the condition.
196 auto condition_inst = ir_context->get_def_use_mgr()->GetDef(bool_id);
197 assert(condition_inst &&
198 (condition_inst->opcode() == SpvOpConstantTrue ||
199 condition_inst->opcode() == SpvOpConstantFalse) &&
200 "|bool_id| is invalid");
201
202 auto condition_value = condition_inst->opcode() == SpvOpConstantTrue;
203 auto successor_id = bb_from->terminator()->GetSingleWordInOperand(0);
204
205 // Add the dead branch, by turning OpBranch into OpBranchConditional, and
206 // ordering the targets depending on whether the given boolean corresponds to
207 // true or false.
208 return opt::Instruction(
209 ir_context, SpvOpBranchConditional, 0, 0,
210 {{SPV_OPERAND_TYPE_ID, {bool_id}},
211 {SPV_OPERAND_TYPE_ID, {condition_value ? successor_id : bb_to_id}},
212 {SPV_OPERAND_TYPE_ID, {condition_value ? bb_to_id : successor_id}}});
213 }
214
AddUnreachableEdgeAndUpdateOpPhis(opt::IRContext * context,opt::BasicBlock * bb_from,opt::BasicBlock * bb_to,uint32_t bool_id,const google::protobuf::RepeatedField<google::protobuf::uint32> & phi_ids)215 void AddUnreachableEdgeAndUpdateOpPhis(
216 opt::IRContext* context, opt::BasicBlock* bb_from, opt::BasicBlock* bb_to,
217 uint32_t bool_id,
218 const google::protobuf::RepeatedField<google::protobuf::uint32>& phi_ids) {
219 assert(PhiIdsOkForNewEdge(context, bb_from, bb_to, phi_ids) &&
220 "Precondition on phi_ids is not satisfied");
221
222 const bool from_to_edge_already_exists = bb_from->IsSuccessor(bb_to);
223 *bb_from->terminator() = CreateUnreachableEdgeInstruction(
224 context, bb_from->id(), bb_to->id(), bool_id);
225
226 // Update OpPhi instructions in the target block if this branch adds a
227 // previously non-existent edge from source to target.
228 if (!from_to_edge_already_exists) {
229 uint32_t phi_index = 0;
230 for (auto& inst : *bb_to) {
231 if (inst.opcode() != SpvOpPhi) {
232 break;
233 }
234 assert(phi_index < static_cast<uint32_t>(phi_ids.size()) &&
235 "There should be at least one phi id per OpPhi instruction.");
236 inst.AddOperand({SPV_OPERAND_TYPE_ID, {phi_ids[phi_index]}});
237 inst.AddOperand({SPV_OPERAND_TYPE_ID, {bb_from->id()}});
238 phi_index++;
239 }
240 }
241 }
242
BlockIsBackEdge(opt::IRContext * context,uint32_t block_id,uint32_t loop_header_id)243 bool BlockIsBackEdge(opt::IRContext* context, uint32_t block_id,
244 uint32_t loop_header_id) {
245 auto block = context->cfg()->block(block_id);
246 auto loop_header = context->cfg()->block(loop_header_id);
247
248 // |block| and |loop_header| must be defined, |loop_header| must be in fact
249 // loop header and |block| must branch to it.
250 if (!(block && loop_header && loop_header->IsLoopHeader() &&
251 block->IsSuccessor(loop_header))) {
252 return false;
253 }
254
255 // |block_id| must be reachable and be dominated by |loop_header|.
256 opt::DominatorAnalysis* dominator_analysis =
257 context->GetDominatorAnalysis(loop_header->GetParent());
258 return dominator_analysis->IsReachable(block_id) &&
259 dominator_analysis->Dominates(loop_header_id, block_id);
260 }
261
BlockIsInLoopContinueConstruct(opt::IRContext * context,uint32_t block_id,uint32_t maybe_loop_header_id)262 bool BlockIsInLoopContinueConstruct(opt::IRContext* context, uint32_t block_id,
263 uint32_t maybe_loop_header_id) {
264 // We deem a block to be part of a loop's continue construct if the loop's
265 // continue target dominates the block.
266 auto containing_construct_block = context->cfg()->block(maybe_loop_header_id);
267 if (containing_construct_block->IsLoopHeader()) {
268 auto continue_target = containing_construct_block->ContinueBlockId();
269 if (context->GetDominatorAnalysis(containing_construct_block->GetParent())
270 ->Dominates(continue_target, block_id)) {
271 return true;
272 }
273 }
274 return false;
275 }
276
GetIteratorForInstruction(opt::BasicBlock * block,const opt::Instruction * inst)277 opt::BasicBlock::iterator GetIteratorForInstruction(
278 opt::BasicBlock* block, const opt::Instruction* inst) {
279 for (auto inst_it = block->begin(); inst_it != block->end(); ++inst_it) {
280 if (inst == &*inst_it) {
281 return inst_it;
282 }
283 }
284 return block->end();
285 }
286
BlockIsReachableInItsFunction(opt::IRContext * context,opt::BasicBlock * bb)287 bool BlockIsReachableInItsFunction(opt::IRContext* context,
288 opt::BasicBlock* bb) {
289 auto enclosing_function = bb->GetParent();
290 return context->GetDominatorAnalysis(enclosing_function)
291 ->Dominates(enclosing_function->entry().get(), bb);
292 }
293
CanInsertOpcodeBeforeInstruction(SpvOp opcode,const opt::BasicBlock::iterator & instruction_in_block)294 bool CanInsertOpcodeBeforeInstruction(
295 SpvOp opcode, const opt::BasicBlock::iterator& instruction_in_block) {
296 if (instruction_in_block->PreviousNode() &&
297 (instruction_in_block->PreviousNode()->opcode() == SpvOpLoopMerge ||
298 instruction_in_block->PreviousNode()->opcode() == SpvOpSelectionMerge)) {
299 // We cannot insert directly after a merge instruction.
300 return false;
301 }
302 if (opcode != SpvOpVariable &&
303 instruction_in_block->opcode() == SpvOpVariable) {
304 // We cannot insert a non-OpVariable instruction directly before a
305 // variable; variables in a function must be contiguous in the entry block.
306 return false;
307 }
308 // We cannot insert a non-OpPhi instruction directly before an OpPhi, because
309 // OpPhi instructions need to be contiguous at the start of a block.
310 return opcode == SpvOpPhi || instruction_in_block->opcode() != SpvOpPhi;
311 }
312
CanMakeSynonymOf(opt::IRContext * ir_context,const TransformationContext & transformation_context,opt::Instruction * inst)313 bool CanMakeSynonymOf(opt::IRContext* ir_context,
314 const TransformationContext& transformation_context,
315 opt::Instruction* inst) {
316 if (inst->opcode() == SpvOpSampledImage) {
317 // The SPIR-V data rules say that only very specific instructions may
318 // may consume the result id of an OpSampledImage, and this excludes the
319 // instructions that are used for making synonyms.
320 return false;
321 }
322 if (!inst->HasResultId()) {
323 // We can only make a synonym of an instruction that generates an id.
324 return false;
325 }
326 if (transformation_context.GetFactManager()->IdIsIrrelevant(
327 inst->result_id())) {
328 // An irrelevant id can't be a synonym of anything.
329 return false;
330 }
331 if (!inst->type_id()) {
332 // We can only make a synonym of an instruction that has a type.
333 return false;
334 }
335 auto type_inst = ir_context->get_def_use_mgr()->GetDef(inst->type_id());
336 if (type_inst->opcode() == SpvOpTypeVoid) {
337 // We only make synonyms of instructions that define objects, and an object
338 // cannot have void type.
339 return false;
340 }
341 if (type_inst->opcode() == SpvOpTypePointer) {
342 switch (inst->opcode()) {
343 case SpvOpConstantNull:
344 case SpvOpUndef:
345 // We disallow making synonyms of null or undefined pointers. This is
346 // to provide the property that if the original shader exhibited no bad
347 // pointer accesses, the transformed shader will not either.
348 return false;
349 default:
350 break;
351 }
352 }
353
354 // We do not make synonyms of objects that have decorations: if the synonym is
355 // not decorated analogously, using the original object vs. its synonymous
356 // form may not be equivalent.
357 return ir_context->get_decoration_mgr()
358 ->GetDecorationsFor(inst->result_id(), true)
359 .empty();
360 }
361
IsCompositeType(const opt::analysis::Type * type)362 bool IsCompositeType(const opt::analysis::Type* type) {
363 return type && (type->AsArray() || type->AsMatrix() || type->AsStruct() ||
364 type->AsVector());
365 }
366
RepeatedFieldToVector(const google::protobuf::RepeatedField<uint32_t> & repeated_field)367 std::vector<uint32_t> RepeatedFieldToVector(
368 const google::protobuf::RepeatedField<uint32_t>& repeated_field) {
369 std::vector<uint32_t> result;
370 for (auto i : repeated_field) {
371 result.push_back(i);
372 }
373 return result;
374 }
375
WalkOneCompositeTypeIndex(opt::IRContext * context,uint32_t base_object_type_id,uint32_t index)376 uint32_t WalkOneCompositeTypeIndex(opt::IRContext* context,
377 uint32_t base_object_type_id,
378 uint32_t index) {
379 auto should_be_composite_type =
380 context->get_def_use_mgr()->GetDef(base_object_type_id);
381 assert(should_be_composite_type && "The type should exist.");
382 switch (should_be_composite_type->opcode()) {
383 case SpvOpTypeArray: {
384 auto array_length = GetArraySize(*should_be_composite_type, context);
385 if (array_length == 0 || index >= array_length) {
386 return 0;
387 }
388 return should_be_composite_type->GetSingleWordInOperand(0);
389 }
390 case SpvOpTypeMatrix:
391 case SpvOpTypeVector: {
392 auto count = should_be_composite_type->GetSingleWordInOperand(1);
393 if (index >= count) {
394 return 0;
395 }
396 return should_be_composite_type->GetSingleWordInOperand(0);
397 }
398 case SpvOpTypeStruct: {
399 if (index >= GetNumberOfStructMembers(*should_be_composite_type)) {
400 return 0;
401 }
402 return should_be_composite_type->GetSingleWordInOperand(index);
403 }
404 default:
405 return 0;
406 }
407 }
408
WalkCompositeTypeIndices(opt::IRContext * context,uint32_t base_object_type_id,const google::protobuf::RepeatedField<google::protobuf::uint32> & indices)409 uint32_t WalkCompositeTypeIndices(
410 opt::IRContext* context, uint32_t base_object_type_id,
411 const google::protobuf::RepeatedField<google::protobuf::uint32>& indices) {
412 uint32_t sub_object_type_id = base_object_type_id;
413 for (auto index : indices) {
414 sub_object_type_id =
415 WalkOneCompositeTypeIndex(context, sub_object_type_id, index);
416 if (!sub_object_type_id) {
417 return 0;
418 }
419 }
420 return sub_object_type_id;
421 }
422
GetNumberOfStructMembers(const opt::Instruction & struct_type_instruction)423 uint32_t GetNumberOfStructMembers(
424 const opt::Instruction& struct_type_instruction) {
425 assert(struct_type_instruction.opcode() == SpvOpTypeStruct &&
426 "An OpTypeStruct instruction is required here.");
427 return struct_type_instruction.NumInOperands();
428 }
429
GetArraySize(const opt::Instruction & array_type_instruction,opt::IRContext * context)430 uint32_t GetArraySize(const opt::Instruction& array_type_instruction,
431 opt::IRContext* context) {
432 auto array_length_constant =
433 context->get_constant_mgr()
434 ->GetConstantFromInst(context->get_def_use_mgr()->GetDef(
435 array_type_instruction.GetSingleWordInOperand(1)))
436 ->AsIntConstant();
437 if (array_length_constant->words().size() != 1) {
438 return 0;
439 }
440 return array_length_constant->GetU32();
441 }
442
GetBoundForCompositeIndex(const opt::Instruction & composite_type_inst,opt::IRContext * ir_context)443 uint32_t GetBoundForCompositeIndex(const opt::Instruction& composite_type_inst,
444 opt::IRContext* ir_context) {
445 switch (composite_type_inst.opcode()) {
446 case SpvOpTypeArray:
447 return fuzzerutil::GetArraySize(composite_type_inst, ir_context);
448 case SpvOpTypeMatrix:
449 case SpvOpTypeVector:
450 return composite_type_inst.GetSingleWordInOperand(1);
451 case SpvOpTypeStruct: {
452 return fuzzerutil::GetNumberOfStructMembers(composite_type_inst);
453 }
454 case SpvOpTypeRuntimeArray:
455 assert(false &&
456 "GetBoundForCompositeIndex should not be invoked with an "
457 "OpTypeRuntimeArray, which does not have a static bound.");
458 return 0;
459 default:
460 assert(false && "Unknown composite type.");
461 return 0;
462 }
463 }
464
IsValid(const opt::IRContext * context,spv_validator_options validator_options,MessageConsumer consumer)465 bool IsValid(const opt::IRContext* context,
466 spv_validator_options validator_options,
467 MessageConsumer consumer) {
468 std::vector<uint32_t> binary;
469 context->module()->ToBinary(&binary, false);
470 SpirvTools tools(context->grammar().target_env());
471 tools.SetMessageConsumer(std::move(consumer));
472 return tools.Validate(binary.data(), binary.size(), validator_options);
473 }
474
IsValidAndWellFormed(const opt::IRContext * ir_context,spv_validator_options validator_options,MessageConsumer consumer)475 bool IsValidAndWellFormed(const opt::IRContext* ir_context,
476 spv_validator_options validator_options,
477 MessageConsumer consumer) {
478 if (!IsValid(ir_context, validator_options, consumer)) {
479 // Expression to dump |ir_context| to /data/temp/shader.spv:
480 // DumpShader(ir_context, "/data/temp/shader.spv")
481 consumer(SPV_MSG_INFO, nullptr, {},
482 "Module is invalid (set a breakpoint to inspect).");
483 return false;
484 }
485 // Check that all blocks in the module have appropriate parent functions.
486 for (auto& function : *ir_context->module()) {
487 for (auto& block : function) {
488 if (block.GetParent() == nullptr) {
489 std::stringstream ss;
490 ss << "Block " << block.id() << " has no parent; its parent should be "
491 << function.result_id() << " (set a breakpoint to inspect).";
492 consumer(SPV_MSG_INFO, nullptr, {}, ss.str().c_str());
493 return false;
494 }
495 if (block.GetParent() != &function) {
496 std::stringstream ss;
497 ss << "Block " << block.id() << " should have parent "
498 << function.result_id() << " but instead has parent "
499 << block.GetParent() << " (set a breakpoint to inspect).";
500 consumer(SPV_MSG_INFO, nullptr, {}, ss.str().c_str());
501 return false;
502 }
503 }
504 }
505
506 // Check that all instructions have distinct unique ids. We map each unique
507 // id to the first instruction it is observed to be associated with so that
508 // if we encounter a duplicate we have access to the previous instruction -
509 // this is a useful aid to debugging.
510 std::unordered_map<uint32_t, opt::Instruction*> unique_ids;
511 bool found_duplicate = false;
512 ir_context->module()->ForEachInst([&consumer, &found_duplicate,
513 &unique_ids](opt::Instruction* inst) {
514 if (unique_ids.count(inst->unique_id()) != 0) {
515 consumer(SPV_MSG_INFO, nullptr, {},
516 "Two instructions have the same unique id (set a breakpoint to "
517 "inspect).");
518 found_duplicate = true;
519 }
520 unique_ids.insert({inst->unique_id(), inst});
521 });
522 return !found_duplicate;
523 }
524
CloneIRContext(opt::IRContext * context)525 std::unique_ptr<opt::IRContext> CloneIRContext(opt::IRContext* context) {
526 std::vector<uint32_t> binary;
527 context->module()->ToBinary(&binary, false);
528 return BuildModule(context->grammar().target_env(), nullptr, binary.data(),
529 binary.size());
530 }
531
IsNonFunctionTypeId(opt::IRContext * ir_context,uint32_t id)532 bool IsNonFunctionTypeId(opt::IRContext* ir_context, uint32_t id) {
533 auto type = ir_context->get_type_mgr()->GetType(id);
534 return type && !type->AsFunction();
535 }
536
IsMergeOrContinue(opt::IRContext * ir_context,uint32_t block_id)537 bool IsMergeOrContinue(opt::IRContext* ir_context, uint32_t block_id) {
538 bool result = false;
539 ir_context->get_def_use_mgr()->WhileEachUse(
540 block_id,
541 [&result](const opt::Instruction* use_instruction,
542 uint32_t /*unused*/) -> bool {
543 switch (use_instruction->opcode()) {
544 case SpvOpLoopMerge:
545 case SpvOpSelectionMerge:
546 result = true;
547 return false;
548 default:
549 return true;
550 }
551 });
552 return result;
553 }
554
GetLoopFromMergeBlock(opt::IRContext * ir_context,uint32_t merge_block_id)555 uint32_t GetLoopFromMergeBlock(opt::IRContext* ir_context,
556 uint32_t merge_block_id) {
557 uint32_t result = 0;
558 ir_context->get_def_use_mgr()->WhileEachUse(
559 merge_block_id,
560 [ir_context, &result](opt::Instruction* use_instruction,
561 uint32_t use_index) -> bool {
562 switch (use_instruction->opcode()) {
563 case SpvOpLoopMerge:
564 // The merge block operand is the first operand in OpLoopMerge.
565 if (use_index == 0) {
566 result = ir_context->get_instr_block(use_instruction)->id();
567 return false;
568 }
569 return true;
570 default:
571 return true;
572 }
573 });
574 return result;
575 }
576
FindFunctionType(opt::IRContext * ir_context,const std::vector<uint32_t> & type_ids)577 uint32_t FindFunctionType(opt::IRContext* ir_context,
578 const std::vector<uint32_t>& type_ids) {
579 // Look through the existing types for a match.
580 for (auto& type_or_value : ir_context->types_values()) {
581 if (type_or_value.opcode() != SpvOpTypeFunction) {
582 // We are only interested in function types.
583 continue;
584 }
585 if (type_or_value.NumInOperands() != type_ids.size()) {
586 // Not a match: different numbers of arguments.
587 continue;
588 }
589 // Check whether the return type and argument types match.
590 bool input_operands_match = true;
591 for (uint32_t i = 0; i < type_or_value.NumInOperands(); i++) {
592 if (type_ids[i] != type_or_value.GetSingleWordInOperand(i)) {
593 input_operands_match = false;
594 break;
595 }
596 }
597 if (input_operands_match) {
598 // Everything matches.
599 return type_or_value.result_id();
600 }
601 }
602 // No match was found.
603 return 0;
604 }
605
GetFunctionType(opt::IRContext * context,const opt::Function * function)606 opt::Instruction* GetFunctionType(opt::IRContext* context,
607 const opt::Function* function) {
608 uint32_t type_id = function->DefInst().GetSingleWordInOperand(1);
609 return context->get_def_use_mgr()->GetDef(type_id);
610 }
611
FindFunction(opt::IRContext * ir_context,uint32_t function_id)612 opt::Function* FindFunction(opt::IRContext* ir_context, uint32_t function_id) {
613 for (auto& function : *ir_context->module()) {
614 if (function.result_id() == function_id) {
615 return &function;
616 }
617 }
618 return nullptr;
619 }
620
FunctionContainsOpKillOrUnreachable(const opt::Function & function)621 bool FunctionContainsOpKillOrUnreachable(const opt::Function& function) {
622 for (auto& block : function) {
623 if (block.terminator()->opcode() == SpvOpKill ||
624 block.terminator()->opcode() == SpvOpUnreachable) {
625 return true;
626 }
627 }
628 return false;
629 }
630
FunctionIsEntryPoint(opt::IRContext * context,uint32_t function_id)631 bool FunctionIsEntryPoint(opt::IRContext* context, uint32_t function_id) {
632 for (auto& entry_point : context->module()->entry_points()) {
633 if (entry_point.GetSingleWordInOperand(1) == function_id) {
634 return true;
635 }
636 }
637 return false;
638 }
639
IdIsAvailableAtUse(opt::IRContext * context,opt::Instruction * use_instruction,uint32_t use_input_operand_index,uint32_t id)640 bool IdIsAvailableAtUse(opt::IRContext* context,
641 opt::Instruction* use_instruction,
642 uint32_t use_input_operand_index, uint32_t id) {
643 assert(context->get_instr_block(use_instruction) &&
644 "|use_instruction| must be in a basic block");
645
646 auto defining_instruction = context->get_def_use_mgr()->GetDef(id);
647 auto enclosing_function =
648 context->get_instr_block(use_instruction)->GetParent();
649 // If the id a function parameter, it needs to be associated with the
650 // function containing the use.
651 if (defining_instruction->opcode() == SpvOpFunctionParameter) {
652 return InstructionIsFunctionParameter(defining_instruction,
653 enclosing_function);
654 }
655 if (!context->get_instr_block(id)) {
656 // The id must be at global scope.
657 return true;
658 }
659 if (defining_instruction == use_instruction) {
660 // It is not OK for a definition to use itself.
661 return false;
662 }
663 auto dominator_analysis = context->GetDominatorAnalysis(enclosing_function);
664 if (!dominator_analysis->IsReachable(
665 context->get_instr_block(use_instruction)) ||
666 !dominator_analysis->IsReachable(context->get_instr_block(id))) {
667 // Skip unreachable blocks.
668 return false;
669 }
670 if (use_instruction->opcode() == SpvOpPhi) {
671 // In the case where the use is an operand to OpPhi, it is actually the
672 // *parent* block associated with the operand that must be dominated by
673 // the synonym.
674 auto parent_block =
675 use_instruction->GetSingleWordInOperand(use_input_operand_index + 1);
676 return dominator_analysis->Dominates(
677 context->get_instr_block(defining_instruction)->id(), parent_block);
678 }
679 return dominator_analysis->Dominates(defining_instruction, use_instruction);
680 }
681
IdIsAvailableBeforeInstruction(opt::IRContext * context,opt::Instruction * instruction,uint32_t id)682 bool IdIsAvailableBeforeInstruction(opt::IRContext* context,
683 opt::Instruction* instruction,
684 uint32_t id) {
685 assert(context->get_instr_block(instruction) &&
686 "|instruction| must be in a basic block");
687
688 auto id_definition = context->get_def_use_mgr()->GetDef(id);
689 auto function_enclosing_instruction =
690 context->get_instr_block(instruction)->GetParent();
691 // If the id a function parameter, it needs to be associated with the
692 // function containing the instruction.
693 if (id_definition->opcode() == SpvOpFunctionParameter) {
694 return InstructionIsFunctionParameter(id_definition,
695 function_enclosing_instruction);
696 }
697 if (!context->get_instr_block(id)) {
698 // The id is at global scope.
699 return true;
700 }
701 if (id_definition == instruction) {
702 // The instruction is not available right before its own definition.
703 return false;
704 }
705 const auto* dominator_analysis =
706 context->GetDominatorAnalysis(function_enclosing_instruction);
707 if (dominator_analysis->IsReachable(context->get_instr_block(instruction)) &&
708 dominator_analysis->IsReachable(context->get_instr_block(id)) &&
709 dominator_analysis->Dominates(id_definition, instruction)) {
710 // The id's definition dominates the instruction, and both the definition
711 // and the instruction are in reachable blocks, thus the id is available at
712 // the instruction.
713 return true;
714 }
715 if (id_definition->opcode() == SpvOpVariable &&
716 function_enclosing_instruction ==
717 context->get_instr_block(id)->GetParent()) {
718 assert(!dominator_analysis->IsReachable(
719 context->get_instr_block(instruction)) &&
720 "If the instruction were in a reachable block we should already "
721 "have returned true.");
722 // The id is a variable and it is in the same function as |instruction|.
723 // This is OK despite |instruction| being unreachable.
724 return true;
725 }
726 return false;
727 }
728
InstructionIsFunctionParameter(opt::Instruction * instruction,opt::Function * function)729 bool InstructionIsFunctionParameter(opt::Instruction* instruction,
730 opt::Function* function) {
731 if (instruction->opcode() != SpvOpFunctionParameter) {
732 return false;
733 }
734 bool found_parameter = false;
735 function->ForEachParam(
736 [instruction, &found_parameter](opt::Instruction* param) {
737 if (param == instruction) {
738 found_parameter = true;
739 }
740 });
741 return found_parameter;
742 }
743
GetTypeId(opt::IRContext * context,uint32_t result_id)744 uint32_t GetTypeId(opt::IRContext* context, uint32_t result_id) {
745 const auto* inst = context->get_def_use_mgr()->GetDef(result_id);
746 assert(inst && "|result_id| is invalid");
747 return inst->type_id();
748 }
749
GetPointeeTypeIdFromPointerType(opt::Instruction * pointer_type_inst)750 uint32_t GetPointeeTypeIdFromPointerType(opt::Instruction* pointer_type_inst) {
751 assert(pointer_type_inst && pointer_type_inst->opcode() == SpvOpTypePointer &&
752 "Precondition: |pointer_type_inst| must be OpTypePointer.");
753 return pointer_type_inst->GetSingleWordInOperand(1);
754 }
755
GetPointeeTypeIdFromPointerType(opt::IRContext * context,uint32_t pointer_type_id)756 uint32_t GetPointeeTypeIdFromPointerType(opt::IRContext* context,
757 uint32_t pointer_type_id) {
758 return GetPointeeTypeIdFromPointerType(
759 context->get_def_use_mgr()->GetDef(pointer_type_id));
760 }
761
GetStorageClassFromPointerType(opt::Instruction * pointer_type_inst)762 SpvStorageClass GetStorageClassFromPointerType(
763 opt::Instruction* pointer_type_inst) {
764 assert(pointer_type_inst && pointer_type_inst->opcode() == SpvOpTypePointer &&
765 "Precondition: |pointer_type_inst| must be OpTypePointer.");
766 return static_cast<SpvStorageClass>(
767 pointer_type_inst->GetSingleWordInOperand(0));
768 }
769
GetStorageClassFromPointerType(opt::IRContext * context,uint32_t pointer_type_id)770 SpvStorageClass GetStorageClassFromPointerType(opt::IRContext* context,
771 uint32_t pointer_type_id) {
772 return GetStorageClassFromPointerType(
773 context->get_def_use_mgr()->GetDef(pointer_type_id));
774 }
775
MaybeGetPointerType(opt::IRContext * context,uint32_t pointee_type_id,SpvStorageClass storage_class)776 uint32_t MaybeGetPointerType(opt::IRContext* context, uint32_t pointee_type_id,
777 SpvStorageClass storage_class) {
778 for (auto& inst : context->types_values()) {
779 switch (inst.opcode()) {
780 case SpvOpTypePointer:
781 if (inst.GetSingleWordInOperand(0) == storage_class &&
782 inst.GetSingleWordInOperand(1) == pointee_type_id) {
783 return inst.result_id();
784 }
785 break;
786 default:
787 break;
788 }
789 }
790 return 0;
791 }
792
InOperandIndexFromOperandIndex(const opt::Instruction & inst,uint32_t absolute_index)793 uint32_t InOperandIndexFromOperandIndex(const opt::Instruction& inst,
794 uint32_t absolute_index) {
795 // Subtract the number of non-input operands from the index
796 return absolute_index - inst.NumOperands() + inst.NumInOperands();
797 }
798
IsNullConstantSupported(const opt::analysis::Type & type)799 bool IsNullConstantSupported(const opt::analysis::Type& type) {
800 return type.AsBool() || type.AsInteger() || type.AsFloat() ||
801 type.AsMatrix() || type.AsVector() || type.AsArray() ||
802 type.AsStruct() || type.AsPointer() || type.AsEvent() ||
803 type.AsDeviceEvent() || type.AsReserveId() || type.AsQueue();
804 }
805
GlobalVariablesMustBeDeclaredInEntryPointInterfaces(const opt::IRContext * ir_context)806 bool GlobalVariablesMustBeDeclaredInEntryPointInterfaces(
807 const opt::IRContext* ir_context) {
808 // TODO(afd): We capture the environments for which this requirement holds.
809 // The check should be refined on demand for other target environments.
810 switch (ir_context->grammar().target_env()) {
811 case SPV_ENV_UNIVERSAL_1_0:
812 case SPV_ENV_UNIVERSAL_1_1:
813 case SPV_ENV_UNIVERSAL_1_2:
814 case SPV_ENV_UNIVERSAL_1_3:
815 case SPV_ENV_VULKAN_1_0:
816 case SPV_ENV_VULKAN_1_1:
817 return false;
818 default:
819 return true;
820 }
821 }
822
AddVariableIdToEntryPointInterfaces(opt::IRContext * context,uint32_t id)823 void AddVariableIdToEntryPointInterfaces(opt::IRContext* context, uint32_t id) {
824 if (GlobalVariablesMustBeDeclaredInEntryPointInterfaces(context)) {
825 // Conservatively add this global to the interface of every entry point in
826 // the module. This means that the global is available for other
827 // transformations to use.
828 //
829 // A downside of this is that the global will be in the interface even if it
830 // ends up never being used.
831 //
832 // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3111) revisit
833 // this if a more thorough approach to entry point interfaces is taken.
834 for (auto& entry_point : context->module()->entry_points()) {
835 entry_point.AddOperand({SPV_OPERAND_TYPE_ID, {id}});
836 }
837 }
838 }
839
AddGlobalVariable(opt::IRContext * context,uint32_t result_id,uint32_t type_id,SpvStorageClass storage_class,uint32_t initializer_id)840 opt::Instruction* AddGlobalVariable(opt::IRContext* context, uint32_t result_id,
841 uint32_t type_id,
842 SpvStorageClass storage_class,
843 uint32_t initializer_id) {
844 // Check various preconditions.
845 assert(result_id != 0 && "Result id can't be 0");
846
847 assert((storage_class == SpvStorageClassPrivate ||
848 storage_class == SpvStorageClassWorkgroup) &&
849 "Variable's storage class must be either Private or Workgroup");
850
851 auto* type_inst = context->get_def_use_mgr()->GetDef(type_id);
852 (void)type_inst; // Variable becomes unused in release mode.
853 assert(type_inst && type_inst->opcode() == SpvOpTypePointer &&
854 GetStorageClassFromPointerType(type_inst) == storage_class &&
855 "Variable's type is invalid");
856
857 if (storage_class == SpvStorageClassWorkgroup) {
858 assert(initializer_id == 0);
859 }
860
861 if (initializer_id != 0) {
862 const auto* constant_inst =
863 context->get_def_use_mgr()->GetDef(initializer_id);
864 (void)constant_inst; // Variable becomes unused in release mode.
865 assert(constant_inst && spvOpcodeIsConstant(constant_inst->opcode()) &&
866 GetPointeeTypeIdFromPointerType(type_inst) ==
867 constant_inst->type_id() &&
868 "Initializer is invalid");
869 }
870
871 opt::Instruction::OperandList operands = {
872 {SPV_OPERAND_TYPE_STORAGE_CLASS, {static_cast<uint32_t>(storage_class)}}};
873
874 if (initializer_id) {
875 operands.push_back({SPV_OPERAND_TYPE_ID, {initializer_id}});
876 }
877
878 auto new_instruction = MakeUnique<opt::Instruction>(
879 context, SpvOpVariable, type_id, result_id, std::move(operands));
880 auto result = new_instruction.get();
881 context->module()->AddGlobalValue(std::move(new_instruction));
882
883 AddVariableIdToEntryPointInterfaces(context, result_id);
884 UpdateModuleIdBound(context, result_id);
885
886 return result;
887 }
888
AddLocalVariable(opt::IRContext * context,uint32_t result_id,uint32_t type_id,uint32_t function_id,uint32_t initializer_id)889 opt::Instruction* AddLocalVariable(opt::IRContext* context, uint32_t result_id,
890 uint32_t type_id, uint32_t function_id,
891 uint32_t initializer_id) {
892 // Check various preconditions.
893 assert(result_id != 0 && "Result id can't be 0");
894
895 auto* type_inst = context->get_def_use_mgr()->GetDef(type_id);
896 (void)type_inst; // Variable becomes unused in release mode.
897 assert(type_inst && type_inst->opcode() == SpvOpTypePointer &&
898 GetStorageClassFromPointerType(type_inst) == SpvStorageClassFunction &&
899 "Variable's type is invalid");
900
901 const auto* constant_inst =
902 context->get_def_use_mgr()->GetDef(initializer_id);
903 (void)constant_inst; // Variable becomes unused in release mode.
904 assert(constant_inst && spvOpcodeIsConstant(constant_inst->opcode()) &&
905 GetPointeeTypeIdFromPointerType(type_inst) ==
906 constant_inst->type_id() &&
907 "Initializer is invalid");
908
909 auto* function = FindFunction(context, function_id);
910 assert(function && "Function id is invalid");
911
912 auto new_instruction = MakeUnique<opt::Instruction>(
913 context, SpvOpVariable, type_id, result_id,
914 opt::Instruction::OperandList{
915 {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
916 {SPV_OPERAND_TYPE_ID, {initializer_id}}});
917 auto result = new_instruction.get();
918 function->begin()->begin()->InsertBefore(std::move(new_instruction));
919
920 UpdateModuleIdBound(context, result_id);
921
922 return result;
923 }
924
// Returns true if some value occurs more than once in |arr|.
bool HasDuplicates(const std::vector<uint32_t>& arr) {
  std::unordered_set<uint32_t> seen;
  for (uint32_t value : arr) {
    // insert() reports false when |value| was already present.
    if (!seen.insert(value).second) {
      return true;
    }
  }
  return false;
}
929
// Returns true if |arr| contains every integer of the closed range [lo, hi]
// exactly once and nothing else. An empty |arr| is a permutation of the range
// only when that range is itself empty, i.e. when |lo| > |hi|.
bool IsPermutationOfRange(const std::vector<uint32_t>& arr, uint32_t lo,
                          uint32_t hi) {
  if (arr.empty()) {
    return hi < lo;
  }

  // Any repeated element rules out a permutation.
  const std::unordered_set<uint32_t> distinct(arr.begin(), arr.end());
  if (distinct.size() != arr.size()) {
    return false;
  }

  // All elements are distinct. They therefore cover [lo, hi] exactly when the
  // smallest is |lo|, the largest is |hi| and there are hi - lo + 1 of them.
  const uint32_t smallest = *std::min_element(arr.begin(), arr.end());
  const uint32_t largest = *std::max_element(arr.begin(), arr.end());
  if (smallest != lo || largest != hi) {
    return false;
  }
  return arr.size() == hi - lo + 1;
}
944
GetParameters(opt::IRContext * ir_context,uint32_t function_id)945 std::vector<opt::Instruction*> GetParameters(opt::IRContext* ir_context,
946 uint32_t function_id) {
947 auto* function = FindFunction(ir_context, function_id);
948 assert(function && "|function_id| is invalid");
949
950 std::vector<opt::Instruction*> result;
951 function->ForEachParam(
952 [&result](opt::Instruction* inst) { result.push_back(inst); });
953
954 return result;
955 }
956
RemoveParameter(opt::IRContext * ir_context,uint32_t parameter_id)957 void RemoveParameter(opt::IRContext* ir_context, uint32_t parameter_id) {
958 auto* function = GetFunctionFromParameterId(ir_context, parameter_id);
959 assert(function && "|parameter_id| is invalid");
960 assert(!FunctionIsEntryPoint(ir_context, function->result_id()) &&
961 "Can't remove parameter from an entry point function");
962
963 function->RemoveParameter(parameter_id);
964
965 // We've just removed parameters from the function and cleared their memory.
966 // Make sure analyses have no dangling pointers.
967 ir_context->InvalidateAnalysesExceptFor(
968 opt::IRContext::Analysis::kAnalysisNone);
969 }
970
GetCallers(opt::IRContext * ir_context,uint32_t function_id)971 std::vector<opt::Instruction*> GetCallers(opt::IRContext* ir_context,
972 uint32_t function_id) {
973 assert(FindFunction(ir_context, function_id) &&
974 "|function_id| is not a result id of a function");
975
976 std::vector<opt::Instruction*> result;
977 ir_context->get_def_use_mgr()->ForEachUser(
978 function_id, [&result, function_id](opt::Instruction* inst) {
979 if (inst->opcode() == SpvOpFunctionCall &&
980 inst->GetSingleWordInOperand(0) == function_id) {
981 result.push_back(inst);
982 }
983 });
984
985 return result;
986 }
987
GetFunctionFromParameterId(opt::IRContext * ir_context,uint32_t param_id)988 opt::Function* GetFunctionFromParameterId(opt::IRContext* ir_context,
989 uint32_t param_id) {
990 auto* param_inst = ir_context->get_def_use_mgr()->GetDef(param_id);
991 assert(param_inst && "Parameter id is invalid");
992
993 for (auto& function : *ir_context->module()) {
994 if (InstructionIsFunctionParameter(param_inst, &function)) {
995 return &function;
996 }
997 }
998
999 return nullptr;
1000 }
1001
UpdateFunctionType(opt::IRContext * ir_context,uint32_t function_id,uint32_t new_function_type_result_id,uint32_t return_type_id,const std::vector<uint32_t> & parameter_type_ids)1002 uint32_t UpdateFunctionType(opt::IRContext* ir_context, uint32_t function_id,
1003 uint32_t new_function_type_result_id,
1004 uint32_t return_type_id,
1005 const std::vector<uint32_t>& parameter_type_ids) {
1006 // Check some initial constraints.
1007 assert(ir_context->get_type_mgr()->GetType(return_type_id) &&
1008 "Return type is invalid");
1009 for (auto id : parameter_type_ids) {
1010 const auto* type = ir_context->get_type_mgr()->GetType(id);
1011 (void)type; // Make compilers happy in release mode.
1012 // Parameters can't be OpTypeVoid.
1013 assert(type && !type->AsVoid() && "Parameter has invalid type");
1014 }
1015
1016 auto* function = FindFunction(ir_context, function_id);
1017 assert(function && "|function_id| is invalid");
1018
1019 auto* old_function_type = GetFunctionType(ir_context, function);
1020 assert(old_function_type && "Function has invalid type");
1021
1022 std::vector<uint32_t> operand_ids = {return_type_id};
1023 operand_ids.insert(operand_ids.end(), parameter_type_ids.begin(),
1024 parameter_type_ids.end());
1025
1026 // A trivial case - we change nothing.
1027 if (FindFunctionType(ir_context, operand_ids) ==
1028 old_function_type->result_id()) {
1029 return old_function_type->result_id();
1030 }
1031
1032 if (ir_context->get_def_use_mgr()->NumUsers(old_function_type) == 1 &&
1033 FindFunctionType(ir_context, operand_ids) == 0) {
1034 // We can change |old_function_type| only if it's used once in the module
1035 // and we are certain we won't create a duplicate as a result of the change.
1036
1037 // Update |old_function_type| in-place.
1038 opt::Instruction::OperandList operands;
1039 for (auto id : operand_ids) {
1040 operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
1041 }
1042
1043 old_function_type->SetInOperands(std::move(operands));
1044
1045 // |operands| may depend on result ids defined below the |old_function_type|
1046 // in the module.
1047 old_function_type->RemoveFromList();
1048 ir_context->AddType(std::unique_ptr<opt::Instruction>(old_function_type));
1049 return old_function_type->result_id();
1050 } else {
1051 // We can't modify the |old_function_type| so we have to either use an
1052 // existing one or create a new one.
1053 auto type_id = FindOrCreateFunctionType(
1054 ir_context, new_function_type_result_id, operand_ids);
1055 assert(type_id != old_function_type->result_id() &&
1056 "We should've handled this case above");
1057
1058 function->DefInst().SetInOperand(1, {type_id});
1059
1060 // DefUseManager hasn't been updated yet, so if the following condition is
1061 // true, then |old_function_type| will have no users when this function
1062 // returns. We might as well remove it.
1063 if (ir_context->get_def_use_mgr()->NumUsers(old_function_type) == 1) {
1064 ir_context->KillInst(old_function_type);
1065 }
1066
1067 return type_id;
1068 }
1069 }
1070
AddFunctionType(opt::IRContext * ir_context,uint32_t result_id,const std::vector<uint32_t> & type_ids)1071 void AddFunctionType(opt::IRContext* ir_context, uint32_t result_id,
1072 const std::vector<uint32_t>& type_ids) {
1073 assert(result_id != 0 && "Result id can't be 0");
1074 assert(!type_ids.empty() &&
1075 "OpTypeFunction always has at least one operand - function's return "
1076 "type");
1077 assert(IsNonFunctionTypeId(ir_context, type_ids[0]) &&
1078 "Return type must not be a function");
1079
1080 for (size_t i = 1; i < type_ids.size(); ++i) {
1081 const auto* param_type = ir_context->get_type_mgr()->GetType(type_ids[i]);
1082 (void)param_type; // Make compiler happy in release mode.
1083 assert(param_type && !param_type->AsVoid() && !param_type->AsFunction() &&
1084 "Function parameter can't have a function or void type");
1085 }
1086
1087 opt::Instruction::OperandList operands;
1088 operands.reserve(type_ids.size());
1089 for (auto id : type_ids) {
1090 operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
1091 }
1092
1093 ir_context->AddType(MakeUnique<opt::Instruction>(
1094 ir_context, SpvOpTypeFunction, 0, result_id, std::move(operands)));
1095
1096 UpdateModuleIdBound(ir_context, result_id);
1097 }
1098
FindOrCreateFunctionType(opt::IRContext * ir_context,uint32_t result_id,const std::vector<uint32_t> & type_ids)1099 uint32_t FindOrCreateFunctionType(opt::IRContext* ir_context,
1100 uint32_t result_id,
1101 const std::vector<uint32_t>& type_ids) {
1102 if (auto existing_id = FindFunctionType(ir_context, type_ids)) {
1103 return existing_id;
1104 }
1105 AddFunctionType(ir_context, result_id, type_ids);
1106 return result_id;
1107 }
1108
MaybeGetIntegerType(opt::IRContext * ir_context,uint32_t width,bool is_signed)1109 uint32_t MaybeGetIntegerType(opt::IRContext* ir_context, uint32_t width,
1110 bool is_signed) {
1111 opt::analysis::Integer type(width, is_signed);
1112 return ir_context->get_type_mgr()->GetId(&type);
1113 }
1114
MaybeGetFloatType(opt::IRContext * ir_context,uint32_t width)1115 uint32_t MaybeGetFloatType(opt::IRContext* ir_context, uint32_t width) {
1116 opt::analysis::Float type(width);
1117 return ir_context->get_type_mgr()->GetId(&type);
1118 }
1119
MaybeGetBoolType(opt::IRContext * ir_context)1120 uint32_t MaybeGetBoolType(opt::IRContext* ir_context) {
1121 opt::analysis::Bool type;
1122 return ir_context->get_type_mgr()->GetId(&type);
1123 }
1124
MaybeGetVectorType(opt::IRContext * ir_context,uint32_t component_type_id,uint32_t element_count)1125 uint32_t MaybeGetVectorType(opt::IRContext* ir_context,
1126 uint32_t component_type_id,
1127 uint32_t element_count) {
1128 const auto* component_type =
1129 ir_context->get_type_mgr()->GetType(component_type_id);
1130 assert(component_type &&
1131 (component_type->AsInteger() || component_type->AsFloat() ||
1132 component_type->AsBool()) &&
1133 "|component_type_id| is invalid");
1134 assert(element_count >= 2 && element_count <= 4 &&
1135 "Precondition: component count must be in range [2, 4].");
1136 opt::analysis::Vector type(component_type, element_count);
1137 return ir_context->get_type_mgr()->GetId(&type);
1138 }
1139
MaybeGetStructType(opt::IRContext * ir_context,const std::vector<uint32_t> & component_type_ids)1140 uint32_t MaybeGetStructType(opt::IRContext* ir_context,
1141 const std::vector<uint32_t>& component_type_ids) {
1142 for (auto& type_or_value : ir_context->types_values()) {
1143 if (type_or_value.opcode() != SpvOpTypeStruct ||
1144 type_or_value.NumInOperands() !=
1145 static_cast<uint32_t>(component_type_ids.size())) {
1146 continue;
1147 }
1148 bool all_components_match = true;
1149 for (uint32_t i = 0; i < component_type_ids.size(); i++) {
1150 if (type_or_value.GetSingleWordInOperand(i) != component_type_ids[i]) {
1151 all_components_match = false;
1152 break;
1153 }
1154 }
1155 if (all_components_match) {
1156 return type_or_value.result_id();
1157 }
1158 }
1159 return 0;
1160 }
1161
MaybeGetVoidType(opt::IRContext * ir_context)1162 uint32_t MaybeGetVoidType(opt::IRContext* ir_context) {
1163 opt::analysis::Void type;
1164 return ir_context->get_type_mgr()->GetId(&type);
1165 }
1166
MaybeGetZeroConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,uint32_t scalar_or_composite_type_id,bool is_irrelevant)1167 uint32_t MaybeGetZeroConstant(
1168 opt::IRContext* ir_context,
1169 const TransformationContext& transformation_context,
1170 uint32_t scalar_or_composite_type_id, bool is_irrelevant) {
1171 const auto* type_inst =
1172 ir_context->get_def_use_mgr()->GetDef(scalar_or_composite_type_id);
1173 assert(type_inst && "|scalar_or_composite_type_id| is invalid");
1174
1175 switch (type_inst->opcode()) {
1176 case SpvOpTypeBool:
1177 return MaybeGetBoolConstant(ir_context, transformation_context, false,
1178 is_irrelevant);
1179 case SpvOpTypeFloat:
1180 case SpvOpTypeInt: {
1181 const auto width = type_inst->GetSingleWordInOperand(0);
1182 std::vector<uint32_t> words = {0};
1183 if (width > 32) {
1184 words.push_back(0);
1185 }
1186
1187 return MaybeGetScalarConstant(ir_context, transformation_context, words,
1188 scalar_or_composite_type_id, is_irrelevant);
1189 }
1190 case SpvOpTypeStruct: {
1191 std::vector<uint32_t> component_ids;
1192 for (uint32_t i = 0; i < type_inst->NumInOperands(); ++i) {
1193 const auto component_type_id = type_inst->GetSingleWordInOperand(i);
1194
1195 auto component_id =
1196 MaybeGetZeroConstant(ir_context, transformation_context,
1197 component_type_id, is_irrelevant);
1198
1199 if (component_id == 0 && is_irrelevant) {
1200 // Irrelevant constants can use either relevant or irrelevant
1201 // constituents.
1202 component_id = MaybeGetZeroConstant(
1203 ir_context, transformation_context, component_type_id, false);
1204 }
1205
1206 if (component_id == 0) {
1207 return 0;
1208 }
1209
1210 component_ids.push_back(component_id);
1211 }
1212
1213 return MaybeGetCompositeConstant(
1214 ir_context, transformation_context, component_ids,
1215 scalar_or_composite_type_id, is_irrelevant);
1216 }
1217 case SpvOpTypeMatrix:
1218 case SpvOpTypeVector: {
1219 const auto component_type_id = type_inst->GetSingleWordInOperand(0);
1220
1221 auto component_id = MaybeGetZeroConstant(
1222 ir_context, transformation_context, component_type_id, is_irrelevant);
1223
1224 if (component_id == 0 && is_irrelevant) {
1225 // Irrelevant constants can use either relevant or irrelevant
1226 // constituents.
1227 component_id = MaybeGetZeroConstant(ir_context, transformation_context,
1228 component_type_id, false);
1229 }
1230
1231 if (component_id == 0) {
1232 return 0;
1233 }
1234
1235 const auto component_count = type_inst->GetSingleWordInOperand(1);
1236 return MaybeGetCompositeConstant(
1237 ir_context, transformation_context,
1238 std::vector<uint32_t>(component_count, component_id),
1239 scalar_or_composite_type_id, is_irrelevant);
1240 }
1241 case SpvOpTypeArray: {
1242 const auto component_type_id = type_inst->GetSingleWordInOperand(0);
1243
1244 auto component_id = MaybeGetZeroConstant(
1245 ir_context, transformation_context, component_type_id, is_irrelevant);
1246
1247 if (component_id == 0 && is_irrelevant) {
1248 // Irrelevant constants can use either relevant or irrelevant
1249 // constituents.
1250 component_id = MaybeGetZeroConstant(ir_context, transformation_context,
1251 component_type_id, false);
1252 }
1253
1254 if (component_id == 0) {
1255 return 0;
1256 }
1257
1258 return MaybeGetCompositeConstant(
1259 ir_context, transformation_context,
1260 std::vector<uint32_t>(GetArraySize(*type_inst, ir_context),
1261 component_id),
1262 scalar_or_composite_type_id, is_irrelevant);
1263 }
1264 default:
1265 assert(false && "Type is not supported");
1266 return 0;
1267 }
1268 }
1269
CanCreateConstant(opt::IRContext * ir_context,uint32_t type_id)1270 bool CanCreateConstant(opt::IRContext* ir_context, uint32_t type_id) {
1271 opt::Instruction* type_instr = ir_context->get_def_use_mgr()->GetDef(type_id);
1272 assert(type_instr != nullptr && "The type must exist.");
1273 assert(spvOpcodeGeneratesType(type_instr->opcode()) &&
1274 "A type-generating opcode was expected.");
1275 switch (type_instr->opcode()) {
1276 case SpvOpTypeBool:
1277 case SpvOpTypeInt:
1278 case SpvOpTypeFloat:
1279 case SpvOpTypeMatrix:
1280 case SpvOpTypeVector:
1281 return true;
1282 case SpvOpTypeArray:
1283 return CanCreateConstant(ir_context,
1284 type_instr->GetSingleWordInOperand(0));
1285 case SpvOpTypeStruct:
1286 if (HasBlockOrBufferBlockDecoration(ir_context, type_id)) {
1287 return false;
1288 }
1289 for (uint32_t index = 0; index < type_instr->NumInOperands(); index++) {
1290 if (!CanCreateConstant(ir_context,
1291 type_instr->GetSingleWordInOperand(index))) {
1292 return false;
1293 }
1294 }
1295 return true;
1296 default:
1297 return false;
1298 }
1299 }
1300
MaybeGetScalarConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t scalar_type_id,bool is_irrelevant)1301 uint32_t MaybeGetScalarConstant(
1302 opt::IRContext* ir_context,
1303 const TransformationContext& transformation_context,
1304 const std::vector<uint32_t>& words, uint32_t scalar_type_id,
1305 bool is_irrelevant) {
1306 const auto* type = ir_context->get_type_mgr()->GetType(scalar_type_id);
1307 assert(type && "|scalar_type_id| is invalid");
1308
1309 if (const auto* int_type = type->AsInteger()) {
1310 return MaybeGetIntegerConstant(ir_context, transformation_context, words,
1311 int_type->width(), int_type->IsSigned(),
1312 is_irrelevant);
1313 } else if (const auto* float_type = type->AsFloat()) {
1314 return MaybeGetFloatConstant(ir_context, transformation_context, words,
1315 float_type->width(), is_irrelevant);
1316 } else {
1317 assert(type->AsBool() && words.size() == 1 &&
1318 "|scalar_type_id| doesn't represent a scalar type");
1319 return MaybeGetBoolConstant(ir_context, transformation_context, words[0],
1320 is_irrelevant);
1321 }
1322 }
1323
MaybeGetCompositeConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & component_ids,uint32_t composite_type_id,bool is_irrelevant)1324 uint32_t MaybeGetCompositeConstant(
1325 opt::IRContext* ir_context,
1326 const TransformationContext& transformation_context,
1327 const std::vector<uint32_t>& component_ids, uint32_t composite_type_id,
1328 bool is_irrelevant) {
1329 const auto* type = ir_context->get_type_mgr()->GetType(composite_type_id);
1330 (void)type; // Make compilers happy in release mode.
1331 assert(IsCompositeType(type) && "|composite_type_id| is invalid");
1332
1333 for (const auto& inst : ir_context->types_values()) {
1334 if (inst.opcode() == SpvOpConstantComposite &&
1335 inst.type_id() == composite_type_id &&
1336 transformation_context.GetFactManager()->IdIsIrrelevant(
1337 inst.result_id()) == is_irrelevant &&
1338 inst.NumInOperands() == component_ids.size()) {
1339 bool is_match = true;
1340
1341 for (uint32_t i = 0; i < inst.NumInOperands(); ++i) {
1342 if (inst.GetSingleWordInOperand(i) != component_ids[i]) {
1343 is_match = false;
1344 break;
1345 }
1346 }
1347
1348 if (is_match) {
1349 return inst.result_id();
1350 }
1351 }
1352 }
1353
1354 return 0;
1355 }
1356
MaybeGetIntegerConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t width,bool is_signed,bool is_irrelevant)1357 uint32_t MaybeGetIntegerConstant(
1358 opt::IRContext* ir_context,
1359 const TransformationContext& transformation_context,
1360 const std::vector<uint32_t>& words, uint32_t width, bool is_signed,
1361 bool is_irrelevant) {
1362 if (auto type_id = MaybeGetIntegerType(ir_context, width, is_signed)) {
1363 return MaybeGetOpConstant(ir_context, transformation_context, words,
1364 type_id, is_irrelevant);
1365 }
1366
1367 return 0;
1368 }
1369
MaybeGetIntegerConstantFromValueAndType(opt::IRContext * ir_context,uint32_t value,uint32_t int_type_id)1370 uint32_t MaybeGetIntegerConstantFromValueAndType(opt::IRContext* ir_context,
1371 uint32_t value,
1372 uint32_t int_type_id) {
1373 auto int_type_inst = ir_context->get_def_use_mgr()->GetDef(int_type_id);
1374
1375 assert(int_type_inst && "The given type id must exist.");
1376
1377 auto int_type = ir_context->get_type_mgr()
1378 ->GetType(int_type_inst->result_id())
1379 ->AsInteger();
1380
1381 assert(int_type && int_type->width() == 32 &&
1382 "The given type id must correspond to an 32-bit integer type.");
1383
1384 opt::analysis::IntConstant constant(int_type, {value});
1385
1386 // Check that the constant exists in the module.
1387 if (!ir_context->get_constant_mgr()->FindConstant(&constant)) {
1388 return 0;
1389 }
1390
1391 return ir_context->get_constant_mgr()
1392 ->GetDefiningInstruction(&constant)
1393 ->result_id();
1394 }
1395
MaybeGetFloatConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,const std::vector<uint32_t> & words,uint32_t width,bool is_irrelevant)1396 uint32_t MaybeGetFloatConstant(
1397 opt::IRContext* ir_context,
1398 const TransformationContext& transformation_context,
1399 const std::vector<uint32_t>& words, uint32_t width, bool is_irrelevant) {
1400 if (auto type_id = MaybeGetFloatType(ir_context, width)) {
1401 return MaybeGetOpConstant(ir_context, transformation_context, words,
1402 type_id, is_irrelevant);
1403 }
1404
1405 return 0;
1406 }
1407
MaybeGetBoolConstant(opt::IRContext * ir_context,const TransformationContext & transformation_context,bool value,bool is_irrelevant)1408 uint32_t MaybeGetBoolConstant(
1409 opt::IRContext* ir_context,
1410 const TransformationContext& transformation_context, bool value,
1411 bool is_irrelevant) {
1412 if (auto type_id = MaybeGetBoolType(ir_context)) {
1413 for (const auto& inst : ir_context->types_values()) {
1414 if (inst.opcode() == (value ? SpvOpConstantTrue : SpvOpConstantFalse) &&
1415 inst.type_id() == type_id &&
1416 transformation_context.GetFactManager()->IdIsIrrelevant(
1417 inst.result_id()) == is_irrelevant) {
1418 return inst.result_id();
1419 }
1420 }
1421 }
1422
1423 return 0;
1424 }
1425
// Takes the low |width| bits of |value| and extends them to 64 bits:
// sign-extended if |is_signed|, zero-extended otherwise. Returns the result as
// 32-bit words, lowest word first. One word is returned when |width| <= 32 and
// two words otherwise.
//
// Precondition: 1 <= |width| <= 64. A width of 0 would require shifting a
// 64-bit value by 64 bits, which is undefined behaviour in C++.
std::vector<uint32_t> IntToWords(uint64_t value, uint32_t width,
                                 bool is_signed) {
  assert(width > 0 && "The bit width must be positive");
  assert(width <= 64 && "The bit width should not be more than 64 bits");

  // Shift the |width| significant bits to the top, then shift them back down.
  // The right shift is arithmetic for signed values and logical for unsigned
  // ones.
  const uint32_t shift = 64 - width;
  if (is_signed) {
    value =
        static_cast<uint64_t>(static_cast<int64_t>(value << shift) >> shift);
  } else {
    value = (value << shift) >> shift;
  }

  std::vector<uint32_t> result;
  result.push_back(static_cast<uint32_t>(value));
  if (width > 32) {
    result.push_back(static_cast<uint32_t>(value >> 32));
  }
  return result;
}
1449
TypesAreEqualUpToSign(opt::IRContext * ir_context,uint32_t type1_id,uint32_t type2_id)1450 bool TypesAreEqualUpToSign(opt::IRContext* ir_context, uint32_t type1_id,
1451 uint32_t type2_id) {
1452 if (type1_id == type2_id) {
1453 return true;
1454 }
1455
1456 auto type1 = ir_context->get_type_mgr()->GetType(type1_id);
1457 auto type2 = ir_context->get_type_mgr()->GetType(type2_id);
1458
1459 // Integer scalar types must have the same width
1460 if (type1->AsInteger() && type2->AsInteger()) {
1461 return type1->AsInteger()->width() == type2->AsInteger()->width();
1462 }
1463
1464 // Integer vector types must have the same number of components and their
1465 // component types must be integers with the same width.
1466 if (type1->AsVector() && type2->AsVector()) {
1467 auto component_type1 = type1->AsVector()->element_type()->AsInteger();
1468 auto component_type2 = type2->AsVector()->element_type()->AsInteger();
1469
1470 // Only check the component count and width if they are integer.
1471 if (component_type1 && component_type2) {
1472 return type1->AsVector()->element_count() ==
1473 type2->AsVector()->element_count() &&
1474 component_type1->width() == component_type2->width();
1475 }
1476 }
1477
1478 // In all other cases, the types cannot be considered equal.
1479 return false;
1480 }
1481
RepeatedUInt32PairToMap(const google::protobuf::RepeatedPtrField<protobufs::UInt32Pair> & data)1482 std::map<uint32_t, uint32_t> RepeatedUInt32PairToMap(
1483 const google::protobuf::RepeatedPtrField<protobufs::UInt32Pair>& data) {
1484 std::map<uint32_t, uint32_t> result;
1485
1486 for (const auto& entry : data) {
1487 result[entry.first()] = entry.second();
1488 }
1489
1490 return result;
1491 }
1492
1493 google::protobuf::RepeatedPtrField<protobufs::UInt32Pair>
MapToRepeatedUInt32Pair(const std::map<uint32_t,uint32_t> & data)1494 MapToRepeatedUInt32Pair(const std::map<uint32_t, uint32_t>& data) {
1495 google::protobuf::RepeatedPtrField<protobufs::UInt32Pair> result;
1496
1497 for (const auto& entry : data) {
1498 protobufs::UInt32Pair pair;
1499 pair.set_first(entry.first);
1500 pair.set_second(entry.second);
1501 *result.Add() = std::move(pair);
1502 }
1503
1504 return result;
1505 }
1506
GetLastInsertBeforeInstruction(opt::IRContext * ir_context,uint32_t block_id,SpvOp opcode)1507 opt::Instruction* GetLastInsertBeforeInstruction(opt::IRContext* ir_context,
1508 uint32_t block_id,
1509 SpvOp opcode) {
1510 // CFG::block uses std::map::at which throws an exception when |block_id| is
1511 // invalid. The error message is unhelpful, though. Thus, we test that
1512 // |block_id| is valid here.
1513 const auto* label_inst = ir_context->get_def_use_mgr()->GetDef(block_id);
1514 (void)label_inst; // Make compilers happy in release mode.
1515 assert(label_inst && label_inst->opcode() == SpvOpLabel &&
1516 "|block_id| is invalid");
1517
1518 auto* block = ir_context->cfg()->block(block_id);
1519 auto it = block->rbegin();
1520 assert(it != block->rend() && "Basic block can't be empty");
1521
1522 if (block->GetMergeInst()) {
1523 ++it;
1524 assert(it != block->rend() &&
1525 "|block| must have at least two instructions:"
1526 "terminator and a merge instruction");
1527 }
1528
1529 return CanInsertOpcodeBeforeInstruction(opcode, &*it) ? &*it : nullptr;
1530 }
1531
IdUseCanBeReplaced(opt::IRContext * ir_context,const TransformationContext & transformation_context,opt::Instruction * use_instruction,uint32_t use_in_operand_index)1532 bool IdUseCanBeReplaced(opt::IRContext* ir_context,
1533 const TransformationContext& transformation_context,
1534 opt::Instruction* use_instruction,
1535 uint32_t use_in_operand_index) {
1536 if (spvOpcodeIsAccessChain(use_instruction->opcode()) &&
1537 use_in_operand_index > 0) {
1538 // A replacement for an irrelevant index in OpAccessChain must be clamped
1539 // first.
1540 if (transformation_context.GetFactManager()->IdIsIrrelevant(
1541 use_instruction->GetSingleWordInOperand(use_in_operand_index))) {
1542 return false;
1543 }
1544
1545 // This is an access chain index. If the (sub-)object being accessed by the
1546 // given index has struct type then we cannot replace the use, as it needs
1547 // to be an OpConstant.
1548
1549 // Get the top-level composite type that is being accessed.
1550 auto object_being_accessed = ir_context->get_def_use_mgr()->GetDef(
1551 use_instruction->GetSingleWordInOperand(0));
1552 auto pointer_type =
1553 ir_context->get_type_mgr()->GetType(object_being_accessed->type_id());
1554 assert(pointer_type->AsPointer());
1555 auto composite_type_being_accessed =
1556 pointer_type->AsPointer()->pointee_type();
1557
1558 // Now walk the access chain, tracking the type of each sub-object of the
1559 // composite that is traversed, until the index of interest is reached.
1560 for (uint32_t index_in_operand = 1; index_in_operand < use_in_operand_index;
1561 index_in_operand++) {
1562 // For vectors, matrices and arrays, getting the type of the sub-object is
1563 // trivial. For the struct case, the sub-object type is field-sensitive,
1564 // and depends on the constant index that is used.
1565 if (composite_type_being_accessed->AsVector()) {
1566 composite_type_being_accessed =
1567 composite_type_being_accessed->AsVector()->element_type();
1568 } else if (composite_type_being_accessed->AsMatrix()) {
1569 composite_type_being_accessed =
1570 composite_type_being_accessed->AsMatrix()->element_type();
1571 } else if (composite_type_being_accessed->AsArray()) {
1572 composite_type_being_accessed =
1573 composite_type_being_accessed->AsArray()->element_type();
1574 } else if (composite_type_being_accessed->AsRuntimeArray()) {
1575 composite_type_being_accessed =
1576 composite_type_being_accessed->AsRuntimeArray()->element_type();
1577 } else {
1578 assert(composite_type_being_accessed->AsStruct());
1579 auto constant_index_instruction = ir_context->get_def_use_mgr()->GetDef(
1580 use_instruction->GetSingleWordInOperand(index_in_operand));
1581 assert(constant_index_instruction->opcode() == SpvOpConstant);
1582 uint32_t member_index =
1583 constant_index_instruction->GetSingleWordInOperand(0);
1584 composite_type_being_accessed =
1585 composite_type_being_accessed->AsStruct()
1586 ->element_types()[member_index];
1587 }
1588 }
1589
1590 // We have found the composite type being accessed by the index we are
1591 // considering replacing. If it is a struct, then we cannot do the
1592 // replacement as struct indices must be constants.
1593 if (composite_type_being_accessed->AsStruct()) {
1594 return false;
1595 }
1596 }
1597
1598 if (use_instruction->opcode() == SpvOpFunctionCall &&
1599 use_in_operand_index > 0) {
1600 // This is a function call argument. It is not allowed to have pointer
1601 // type.
1602
1603 // Get the definition of the function being called.
1604 auto function = ir_context->get_def_use_mgr()->GetDef(
1605 use_instruction->GetSingleWordInOperand(0));
1606 // From the function definition, get the function type.
1607 auto function_type = ir_context->get_def_use_mgr()->GetDef(
1608 function->GetSingleWordInOperand(1));
1609 // OpTypeFunction's 0-th input operand is the function return type, and the
1610 // function argument types follow. Because the arguments to OpFunctionCall
1611 // start from input operand 1, we can use |use_in_operand_index| to get the
1612 // type associated with this function argument.
1613 auto parameter_type = ir_context->get_type_mgr()->GetType(
1614 function_type->GetSingleWordInOperand(use_in_operand_index));
1615 if (parameter_type->AsPointer()) {
1616 return false;
1617 }
1618 }
1619
1620 if (use_instruction->opcode() == SpvOpImageTexelPointer &&
1621 use_in_operand_index == 2) {
1622 // The OpImageTexelPointer instruction has a Sample parameter that in some
1623 // situations must be an id for the value 0. To guard against disrupting
1624 // that requirement, we do not replace this argument to that instruction.
1625 return false;
1626 }
1627
1628 return true;
1629 }
1630
MembersHaveBuiltInDecoration(opt::IRContext * ir_context,uint32_t struct_type_id)1631 bool MembersHaveBuiltInDecoration(opt::IRContext* ir_context,
1632 uint32_t struct_type_id) {
1633 const auto* type_inst = ir_context->get_def_use_mgr()->GetDef(struct_type_id);
1634 assert(type_inst && type_inst->opcode() == SpvOpTypeStruct &&
1635 "|struct_type_id| is not a result id of an OpTypeStruct");
1636
1637 uint32_t builtin_count = 0;
1638 ir_context->get_def_use_mgr()->ForEachUser(
1639 type_inst,
1640 [struct_type_id, &builtin_count](const opt::Instruction* user) {
1641 if (user->opcode() == SpvOpMemberDecorate &&
1642 user->GetSingleWordInOperand(0) == struct_type_id &&
1643 static_cast<SpvDecoration>(user->GetSingleWordInOperand(2)) ==
1644 SpvDecorationBuiltIn) {
1645 ++builtin_count;
1646 }
1647 });
1648
1649 assert((builtin_count == 0 || builtin_count == type_inst->NumInOperands()) &&
1650 "The module is invalid: either none or all of the members of "
1651 "|struct_type_id| may be builtin");
1652
1653 return builtin_count != 0;
1654 }
1655
HasBlockOrBufferBlockDecoration(opt::IRContext * ir_context,uint32_t id)1656 bool HasBlockOrBufferBlockDecoration(opt::IRContext* ir_context, uint32_t id) {
1657 for (auto decoration : {SpvDecorationBlock, SpvDecorationBufferBlock}) {
1658 if (!ir_context->get_decoration_mgr()->WhileEachDecoration(
1659 id, decoration, [](const opt::Instruction & /*unused*/) -> bool {
1660 return false;
1661 })) {
1662 return true;
1663 }
1664 }
1665 return false;
1666 }
1667
SplittingBeforeInstructionSeparatesOpSampledImageDefinitionFromUse(opt::BasicBlock * block_to_split,opt::Instruction * split_before)1668 bool SplittingBeforeInstructionSeparatesOpSampledImageDefinitionFromUse(
1669 opt::BasicBlock* block_to_split, opt::Instruction* split_before) {
1670 std::set<uint32_t> sampled_image_result_ids;
1671 bool before_split = true;
1672
1673 // Check all the instructions in the block to split.
1674 for (auto& instruction : *block_to_split) {
1675 if (&instruction == &*split_before) {
1676 before_split = false;
1677 }
1678 if (before_split) {
1679 // If the instruction comes before the split and its opcode is
1680 // OpSampledImage, record its result id.
1681 if (instruction.opcode() == SpvOpSampledImage) {
1682 sampled_image_result_ids.insert(instruction.result_id());
1683 }
1684 } else {
1685 // If the instruction comes after the split, check if ids
1686 // corresponding to OpSampledImage instructions defined before the split
1687 // are used, and return true if they are.
1688 if (!instruction.WhileEachInId(
1689 [&sampled_image_result_ids](uint32_t* id) -> bool {
1690 return !sampled_image_result_ids.count(*id);
1691 })) {
1692 return true;
1693 }
1694 }
1695 }
1696
1697 // No usage that would be separated from the definition has been found.
1698 return false;
1699 }
1700
InstructionHasNoSideEffects(const opt::Instruction & instruction)1701 bool InstructionHasNoSideEffects(const opt::Instruction& instruction) {
1702 switch (instruction.opcode()) {
1703 case SpvOpUndef:
1704 case SpvOpAccessChain:
1705 case SpvOpInBoundsAccessChain:
1706 case SpvOpArrayLength:
1707 case SpvOpVectorExtractDynamic:
1708 case SpvOpVectorInsertDynamic:
1709 case SpvOpVectorShuffle:
1710 case SpvOpCompositeConstruct:
1711 case SpvOpCompositeExtract:
1712 case SpvOpCompositeInsert:
1713 case SpvOpCopyObject:
1714 case SpvOpTranspose:
1715 case SpvOpConvertFToU:
1716 case SpvOpConvertFToS:
1717 case SpvOpConvertSToF:
1718 case SpvOpConvertUToF:
1719 case SpvOpUConvert:
1720 case SpvOpSConvert:
1721 case SpvOpFConvert:
1722 case SpvOpQuantizeToF16:
1723 case SpvOpSatConvertSToU:
1724 case SpvOpSatConvertUToS:
1725 case SpvOpBitcast:
1726 case SpvOpSNegate:
1727 case SpvOpFNegate:
1728 case SpvOpIAdd:
1729 case SpvOpFAdd:
1730 case SpvOpISub:
1731 case SpvOpFSub:
1732 case SpvOpIMul:
1733 case SpvOpFMul:
1734 case SpvOpUDiv:
1735 case SpvOpSDiv:
1736 case SpvOpFDiv:
1737 case SpvOpUMod:
1738 case SpvOpSRem:
1739 case SpvOpSMod:
1740 case SpvOpFRem:
1741 case SpvOpFMod:
1742 case SpvOpVectorTimesScalar:
1743 case SpvOpMatrixTimesScalar:
1744 case SpvOpVectorTimesMatrix:
1745 case SpvOpMatrixTimesVector:
1746 case SpvOpMatrixTimesMatrix:
1747 case SpvOpOuterProduct:
1748 case SpvOpDot:
1749 case SpvOpIAddCarry:
1750 case SpvOpISubBorrow:
1751 case SpvOpUMulExtended:
1752 case SpvOpSMulExtended:
1753 case SpvOpAny:
1754 case SpvOpAll:
1755 case SpvOpIsNan:
1756 case SpvOpIsInf:
1757 case SpvOpIsFinite:
1758 case SpvOpIsNormal:
1759 case SpvOpSignBitSet:
1760 case SpvOpLessOrGreater:
1761 case SpvOpOrdered:
1762 case SpvOpUnordered:
1763 case SpvOpLogicalEqual:
1764 case SpvOpLogicalNotEqual:
1765 case SpvOpLogicalOr:
1766 case SpvOpLogicalAnd:
1767 case SpvOpLogicalNot:
1768 case SpvOpSelect:
1769 case SpvOpIEqual:
1770 case SpvOpINotEqual:
1771 case SpvOpUGreaterThan:
1772 case SpvOpSGreaterThan:
1773 case SpvOpUGreaterThanEqual:
1774 case SpvOpSGreaterThanEqual:
1775 case SpvOpULessThan:
1776 case SpvOpSLessThan:
1777 case SpvOpULessThanEqual:
1778 case SpvOpSLessThanEqual:
1779 case SpvOpFOrdEqual:
1780 case SpvOpFUnordEqual:
1781 case SpvOpFOrdNotEqual:
1782 case SpvOpFUnordNotEqual:
1783 case SpvOpFOrdLessThan:
1784 case SpvOpFUnordLessThan:
1785 case SpvOpFOrdGreaterThan:
1786 case SpvOpFUnordGreaterThan:
1787 case SpvOpFOrdLessThanEqual:
1788 case SpvOpFUnordLessThanEqual:
1789 case SpvOpFOrdGreaterThanEqual:
1790 case SpvOpFUnordGreaterThanEqual:
1791 case SpvOpShiftRightLogical:
1792 case SpvOpShiftRightArithmetic:
1793 case SpvOpShiftLeftLogical:
1794 case SpvOpBitwiseOr:
1795 case SpvOpBitwiseXor:
1796 case SpvOpBitwiseAnd:
1797 case SpvOpNot:
1798 case SpvOpBitFieldInsert:
1799 case SpvOpBitFieldSExtract:
1800 case SpvOpBitFieldUExtract:
1801 case SpvOpBitReverse:
1802 case SpvOpBitCount:
1803 case SpvOpCopyLogical:
1804 case SpvOpPhi:
1805 case SpvOpPtrEqual:
1806 case SpvOpPtrNotEqual:
1807 return true;
1808 default:
1809 return false;
1810 }
1811 }
1812
GetReachableReturnBlocks(opt::IRContext * ir_context,uint32_t function_id)1813 std::set<uint32_t> GetReachableReturnBlocks(opt::IRContext* ir_context,
1814 uint32_t function_id) {
1815 auto function = ir_context->GetFunction(function_id);
1816 assert(function && "The function |function_id| must exist.");
1817
1818 std::set<uint32_t> result;
1819
1820 ir_context->cfg()->ForEachBlockInPostOrder(function->entry().get(),
1821 [&result](opt::BasicBlock* block) {
1822 if (block->IsReturn()) {
1823 result.emplace(block->id());
1824 }
1825 });
1826
1827 return result;
1828 }
1829
NewTerminatorPreservesDominationRules(opt::IRContext * ir_context,uint32_t block_id,opt::Instruction new_terminator)1830 bool NewTerminatorPreservesDominationRules(opt::IRContext* ir_context,
1831 uint32_t block_id,
1832 opt::Instruction new_terminator) {
1833 auto* mutated_block = MaybeFindBlock(ir_context, block_id);
1834 assert(mutated_block && "|block_id| is invalid");
1835
1836 ChangeTerminatorRAII change_terminator_raii(mutated_block,
1837 std::move(new_terminator));
1838 opt::DominatorAnalysis dominator_analysis;
1839 dominator_analysis.InitializeTree(*ir_context->cfg(),
1840 mutated_block->GetParent());
1841
1842 // Check that each dominator appears before each dominated block.
1843 std::unordered_map<uint32_t, size_t> positions;
1844 for (const auto& block : *mutated_block->GetParent()) {
1845 positions[block.id()] = positions.size();
1846 }
1847
1848 std::queue<uint32_t> q({mutated_block->GetParent()->begin()->id()});
1849 std::unordered_set<uint32_t> visited;
1850 while (!q.empty()) {
1851 auto block = q.front();
1852 q.pop();
1853 visited.insert(block);
1854
1855 auto success = ir_context->cfg()->block(block)->WhileEachSuccessorLabel(
1856 [&positions, &visited, &dominator_analysis, block, &q](uint32_t id) {
1857 if (id == block) {
1858 // Handle the case when loop header and continue target are the same
1859 // block.
1860 return true;
1861 }
1862
1863 if (dominator_analysis.Dominates(block, id) &&
1864 positions[block] > positions[id]) {
1865 // |block| dominates |id| but appears after |id| - violates
1866 // domination rules.
1867 return false;
1868 }
1869
1870 if (!visited.count(id)) {
1871 q.push(id);
1872 }
1873
1874 return true;
1875 });
1876
1877 if (!success) {
1878 return false;
1879 }
1880 }
1881
1882 // For each instruction in the |block->GetParent()| function check whether
1883 // all its dependencies satisfy domination rules (i.e. all id operands
1884 // dominate that instruction).
1885 for (const auto& block : *mutated_block->GetParent()) {
1886 if (!dominator_analysis.IsReachable(&block)) {
1887 // If some block is not reachable then we don't need to worry about the
1888 // preservation of domination rules for its instructions.
1889 continue;
1890 }
1891
1892 for (const auto& inst : block) {
1893 for (uint32_t i = 0; i < inst.NumInOperands();
1894 i += inst.opcode() == SpvOpPhi ? 2 : 1) {
1895 const auto& operand = inst.GetInOperand(i);
1896 if (!spvIsInIdType(operand.type)) {
1897 continue;
1898 }
1899
1900 if (MaybeFindBlock(ir_context, operand.words[0])) {
1901 // Ignore operands that refer to OpLabel instructions.
1902 continue;
1903 }
1904
1905 const auto* dependency_block =
1906 ir_context->get_instr_block(operand.words[0]);
1907 if (!dependency_block) {
1908 // A global instruction always dominates all instructions in any
1909 // function.
1910 continue;
1911 }
1912
1913 auto domination_target_id = inst.opcode() == SpvOpPhi
1914 ? inst.GetSingleWordInOperand(i + 1)
1915 : block.id();
1916
1917 if (!dominator_analysis.Dominates(dependency_block->id(),
1918 domination_target_id)) {
1919 return false;
1920 }
1921 }
1922 }
1923 }
1924
1925 return true;
1926 }
1927
1928 } // namespace fuzzerutil
1929 } // namespace fuzz
1930 } // namespace spvtools
1931