1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/fuzz/transformation_add_function.h"
16
17 #include "source/fuzz/fuzzer_util.h"
18 #include "source/fuzz/instruction_message.h"
19
20 namespace spvtools {
21 namespace fuzz {
22
TransformationAddFunction(const spvtools::fuzz::protobufs::TransformationAddFunction & message)23 TransformationAddFunction::TransformationAddFunction(
24 const spvtools::fuzz::protobufs::TransformationAddFunction& message)
25 : message_(message) {}
26
TransformationAddFunction(const std::vector<protobufs::Instruction> & instructions)27 TransformationAddFunction::TransformationAddFunction(
28 const std::vector<protobufs::Instruction>& instructions) {
29 for (auto& instruction : instructions) {
30 *message_.add_instruction() = instruction;
31 }
32 message_.set_is_livesafe(false);
33 }
34
TransformationAddFunction(const std::vector<protobufs::Instruction> & instructions,uint32_t loop_limiter_variable_id,uint32_t loop_limit_constant_id,const std::vector<protobufs::LoopLimiterInfo> & loop_limiters,uint32_t kill_unreachable_return_value_id,const std::vector<protobufs::AccessChainClampingInfo> & access_chain_clampers)35 TransformationAddFunction::TransformationAddFunction(
36 const std::vector<protobufs::Instruction>& instructions,
37 uint32_t loop_limiter_variable_id, uint32_t loop_limit_constant_id,
38 const std::vector<protobufs::LoopLimiterInfo>& loop_limiters,
39 uint32_t kill_unreachable_return_value_id,
40 const std::vector<protobufs::AccessChainClampingInfo>&
41 access_chain_clampers) {
42 for (auto& instruction : instructions) {
43 *message_.add_instruction() = instruction;
44 }
45 message_.set_is_livesafe(true);
46 message_.set_loop_limiter_variable_id(loop_limiter_variable_id);
47 message_.set_loop_limit_constant_id(loop_limit_constant_id);
48 for (auto& loop_limiter : loop_limiters) {
49 *message_.add_loop_limiter_info() = loop_limiter;
50 }
51 message_.set_kill_unreachable_return_value_id(
52 kill_unreachable_return_value_id);
53 for (auto& access_clamper : access_chain_clampers) {
54 *message_.add_access_chain_clamping_info() = access_clamper;
55 }
56 }
57
IsApplicable(opt::IRContext * context,const spvtools::fuzz::FactManager & fact_manager) const58 bool TransformationAddFunction::IsApplicable(
59 opt::IRContext* context,
60 const spvtools::fuzz::FactManager& fact_manager) const {
61 // This transformation may use a lot of ids, all of which need to be fresh
62 // and distinct. This set tracks them.
63 std::set<uint32_t> ids_used_by_this_transformation;
64
65 // Ensure that all result ids in the new function are fresh and distinct.
66 for (auto& instruction : message_.instruction()) {
67 if (instruction.result_id()) {
68 if (!CheckIdIsFreshAndNotUsedByThisTransformation(
69 instruction.result_id(), context,
70 &ids_used_by_this_transformation)) {
71 return false;
72 }
73 }
74 }
75
76 if (message_.is_livesafe()) {
77 // Ensure that all ids provided for making the function livesafe are fresh
78 // and distinct.
79 if (!CheckIdIsFreshAndNotUsedByThisTransformation(
80 message_.loop_limiter_variable_id(), context,
81 &ids_used_by_this_transformation)) {
82 return false;
83 }
84 for (auto& loop_limiter_info : message_.loop_limiter_info()) {
85 if (!CheckIdIsFreshAndNotUsedByThisTransformation(
86 loop_limiter_info.load_id(), context,
87 &ids_used_by_this_transformation)) {
88 return false;
89 }
90 if (!CheckIdIsFreshAndNotUsedByThisTransformation(
91 loop_limiter_info.increment_id(), context,
92 &ids_used_by_this_transformation)) {
93 return false;
94 }
95 if (!CheckIdIsFreshAndNotUsedByThisTransformation(
96 loop_limiter_info.compare_id(), context,
97 &ids_used_by_this_transformation)) {
98 return false;
99 }
100 if (!CheckIdIsFreshAndNotUsedByThisTransformation(
101 loop_limiter_info.logical_op_id(), context,
102 &ids_used_by_this_transformation)) {
103 return false;
104 }
105 }
106 for (auto& access_chain_clamping_info :
107 message_.access_chain_clamping_info()) {
108 for (auto& pair : access_chain_clamping_info.compare_and_select_ids()) {
109 if (!CheckIdIsFreshAndNotUsedByThisTransformation(
110 pair.first(), context, &ids_used_by_this_transformation)) {
111 return false;
112 }
113 if (!CheckIdIsFreshAndNotUsedByThisTransformation(
114 pair.second(), context, &ids_used_by_this_transformation)) {
115 return false;
116 }
117 }
118 }
119 }
120
121 // Because checking all the conditions for a function to be valid is a big
122 // job that the SPIR-V validator can already do, a "try it and see" approach
123 // is taken here.
124
125 // We first clone the current module, so that we can try adding the new
126 // function without risking wrecking |context|.
127 auto cloned_module = fuzzerutil::CloneIRContext(context);
128
129 // We try to add a function to the cloned module, which may fail if
130 // |message_.instruction| is not sufficiently well-formed.
131 if (!TryToAddFunction(cloned_module.get())) {
132 return false;
133 }
134
135 // Check whether the cloned module is still valid after adding the function.
136 // If it is not, the transformation is not applicable.
137 if (!fuzzerutil::IsValid(cloned_module.get())) {
138 return false;
139 }
140
141 if (message_.is_livesafe()) {
142 if (!TryToMakeFunctionLivesafe(cloned_module.get(), fact_manager)) {
143 return false;
144 }
145 // After making the function livesafe, we check validity of the module
146 // again. This is because the turning of OpKill, OpUnreachable and OpReturn
147 // instructions into branches changes control flow graph reachability, which
148 // has the potential to make the module invalid when it was otherwise valid.
149 // It is simpler to rely on the validator to guard against this than to
150 // consider all scenarios when making a function livesafe.
151 if (!fuzzerutil::IsValid(cloned_module.get())) {
152 return false;
153 }
154 }
155 return true;
156 }
157
Apply(opt::IRContext * context,spvtools::fuzz::FactManager * fact_manager) const158 void TransformationAddFunction::Apply(
159 opt::IRContext* context, spvtools::fuzz::FactManager* fact_manager) const {
160 // Add the function to the module. As the transformation is applicable, this
161 // should succeed.
162 bool success = TryToAddFunction(context);
163 assert(success && "The function should be successfully added.");
164 (void)(success); // Keep release builds happy (otherwise they may complain
165 // that |success| is not used).
166
167 // Record the fact that all pointer parameters and variables declared in the
168 // function should be regarded as having irrelevant values. This allows other
169 // passes to store arbitrarily to such variables, and to pass them freely as
170 // parameters to other functions knowing that it is OK if they get
171 // over-written.
172 for (auto& instruction : message_.instruction()) {
173 switch (instruction.opcode()) {
174 case SpvOpFunctionParameter:
175 if (context->get_def_use_mgr()
176 ->GetDef(instruction.result_type_id())
177 ->opcode() == SpvOpTypePointer) {
178 fact_manager->AddFactValueOfPointeeIsIrrelevant(
179 instruction.result_id());
180 }
181 break;
182 case SpvOpVariable:
183 fact_manager->AddFactValueOfPointeeIsIrrelevant(
184 instruction.result_id());
185 break;
186 default:
187 break;
188 }
189 }
190
191 if (message_.is_livesafe()) {
192 // Make the function livesafe, which also should succeed.
193 success = TryToMakeFunctionLivesafe(context, *fact_manager);
194 assert(success && "It should be possible to make the function livesafe.");
195 (void)(success); // Keep release builds happy.
196
197 // Inform the fact manager that the function is livesafe.
198 assert(message_.instruction(0).opcode() == SpvOpFunction &&
199 "The first instruction of an 'add function' transformation must be "
200 "OpFunction.");
201 fact_manager->AddFactFunctionIsLivesafe(
202 message_.instruction(0).result_id());
203 } else {
204 // Inform the fact manager that all blocks in the function are dead.
205 for (auto& inst : message_.instruction()) {
206 if (inst.opcode() == SpvOpLabel) {
207 fact_manager->AddFactBlockIsDead(inst.result_id());
208 }
209 }
210 }
211 context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
212 }
213
ToMessage() const214 protobufs::Transformation TransformationAddFunction::ToMessage() const {
215 protobufs::Transformation result;
216 *result.mutable_add_function() = message_;
217 return result;
218 }
219
TryToAddFunction(opt::IRContext * context) const220 bool TransformationAddFunction::TryToAddFunction(
221 opt::IRContext* context) const {
222 // This function returns false if |message_.instruction| was not well-formed
223 // enough to actually create a function and add it to |context|.
224
225 // A function must have at least some instructions.
226 if (message_.instruction().empty()) {
227 return false;
228 }
229
230 // A function must start with OpFunction.
231 auto function_begin = message_.instruction(0);
232 if (function_begin.opcode() != SpvOpFunction) {
233 return false;
234 }
235
236 // Make a function, headed by the OpFunction instruction.
237 std::unique_ptr<opt::Function> new_function = MakeUnique<opt::Function>(
238 InstructionFromMessage(context, function_begin));
239
240 // Keeps track of which instruction protobuf message we are currently
241 // considering.
242 uint32_t instruction_index = 1;
243 const auto num_instructions =
244 static_cast<uint32_t>(message_.instruction().size());
245
246 // Iterate through all function parameter instructions, adding parameters to
247 // the new function.
248 while (instruction_index < num_instructions &&
249 message_.instruction(instruction_index).opcode() ==
250 SpvOpFunctionParameter) {
251 new_function->AddParameter(InstructionFromMessage(
252 context, message_.instruction(instruction_index)));
253 instruction_index++;
254 }
255
256 // After the parameters, there needs to be a label.
257 if (instruction_index == num_instructions ||
258 message_.instruction(instruction_index).opcode() != SpvOpLabel) {
259 return false;
260 }
261
262 // Iterate through the instructions block by block until the end of the
263 // function is reached.
264 while (instruction_index < num_instructions &&
265 message_.instruction(instruction_index).opcode() != SpvOpFunctionEnd) {
266 // Invariant: we should always be at a label instruction at this point.
267 assert(message_.instruction(instruction_index).opcode() == SpvOpLabel);
268
269 // Make a basic block using the label instruction, with the new function
270 // as its parent.
271 std::unique_ptr<opt::BasicBlock> block =
272 MakeUnique<opt::BasicBlock>(InstructionFromMessage(
273 context, message_.instruction(instruction_index)));
274 block->SetParent(new_function.get());
275
276 // Consider successive instructions until we hit another label or the end
277 // of the function, adding each such instruction to the block.
278 instruction_index++;
279 while (instruction_index < num_instructions &&
280 message_.instruction(instruction_index).opcode() !=
281 SpvOpFunctionEnd &&
282 message_.instruction(instruction_index).opcode() != SpvOpLabel) {
283 block->AddInstruction(InstructionFromMessage(
284 context, message_.instruction(instruction_index)));
285 instruction_index++;
286 }
287 // Add the block to the new function.
288 new_function->AddBasicBlock(std::move(block));
289 }
290 // Having considered all the blocks, we should be at the last instruction and
291 // it needs to be OpFunctionEnd.
292 if (instruction_index != num_instructions - 1 ||
293 message_.instruction(instruction_index).opcode() != SpvOpFunctionEnd) {
294 return false;
295 }
296 // Set the function's final instruction, add the function to the module and
297 // report success.
298 new_function->SetFunctionEnd(
299 InstructionFromMessage(context, message_.instruction(instruction_index)));
300 context->AddFunction(std::move(new_function));
301
302 context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
303
304 return true;
305 }
306
TryToMakeFunctionLivesafe(opt::IRContext * context,const FactManager & fact_manager) const307 bool TransformationAddFunction::TryToMakeFunctionLivesafe(
308 opt::IRContext* context, const FactManager& fact_manager) const {
309 assert(message_.is_livesafe() && "Precondition: is_livesafe must hold.");
310
311 // Get a pointer to the added function.
312 opt::Function* added_function = nullptr;
313 for (auto& function : *context->module()) {
314 if (function.result_id() == message_.instruction(0).result_id()) {
315 added_function = &function;
316 break;
317 }
318 }
319 assert(added_function && "The added function should have been found.");
320
321 if (!TryToAddLoopLimiters(context, added_function)) {
322 // Adding loop limiters did not work; bail out.
323 return false;
324 }
325
326 // Consider all the instructions in the function, and:
327 // - attempt to replace OpKill and OpUnreachable with return instructions
328 // - attempt to clamp access chains to be within bounds
329 // - check that OpFunctionCall instructions are only to livesafe functions
330 for (auto& block : *added_function) {
331 for (auto& inst : block) {
332 switch (inst.opcode()) {
333 case SpvOpKill:
334 case SpvOpUnreachable:
335 if (!TryToTurnKillOrUnreachableIntoReturn(context, added_function,
336 &inst)) {
337 return false;
338 }
339 break;
340 case SpvOpAccessChain:
341 case SpvOpInBoundsAccessChain:
342 if (!TryToClampAccessChainIndices(context, &inst)) {
343 return false;
344 }
345 break;
346 case SpvOpFunctionCall:
347 // A livesafe function my only call other livesafe functions.
348 if (!fact_manager.FunctionIsLivesafe(
349 inst.GetSingleWordInOperand(0))) {
350 return false;
351 }
352 default:
353 break;
354 }
355 }
356 }
357 return true;
358 }
359
// Adds loop-limiting logic to every loop of |added_function|, so that no loop
// in the function can iterate more than a bounded number of times.
//
// A single loop limiter variable, |message_.loop_limiter_variable_id()|, is
// declared at the start of the function's entry block, initialized to 0 and
// shared by all loops. On each loop's back edge the variable is loaded,
// incremented and stored back. The loaded value is compared with the constant
// |message_.loop_limit_constant_id()|, and once that limit is reached the
// back edge branches to the loop's merge block instead of its header.
//
// Returns true immediately if the function has no loops. Otherwise returns
// false if:
// - the loop limit constant is missing, is not an OpConstant, or is not of
//   32-bit integer type;
// - the module lacks the 32-bit unsigned int type, the uint constants 0
//   and 1, a Function-storage pointer-to-uint type, or the bool type;
// - no LoopLimiterInfo is supplied for a loop whose back edge is reachable;
// - for an OpBranch back edge, the supplied OpPhi ids are unsuitable for the
//   new edge to the merge block.
// NOTE(review): a false return can leave |context| partially modified (for
// example, the limiter variable may already have been declared). Callers are
// expected to try this on a cloned module first, as IsApplicable does.
bool TransformationAddFunction::TryToAddLoopLimiters(
    opt::IRContext* context, opt::Function* added_function) const {
  // Collect up all the loop headers so that we can subsequently add loop
  // limiting logic.
  std::vector<opt::BasicBlock*> loop_headers;
  for (auto& block : *added_function) {
    if (block.IsLoopHeader()) {
      loop_headers.push_back(&block);
    }
  }

  if (loop_headers.empty()) {
    // There are no loops, so no need to add any loop limiters.
    return true;
  }

  // Check that the module contains appropriate ingredients for declaring and
  // manipulating a loop limiter.

  auto loop_limit_constant_id_instr =
      context->get_def_use_mgr()->GetDef(message_.loop_limit_constant_id());
  if (!loop_limit_constant_id_instr ||
      loop_limit_constant_id_instr->opcode() != SpvOpConstant) {
    // The loop limit constant id instruction must exist and have an
    // appropriate opcode.
    return false;
  }

  auto loop_limit_type = context->get_def_use_mgr()->GetDef(
      loop_limit_constant_id_instr->type_id());
  if (loop_limit_type->opcode() != SpvOpTypeInt ||
      loop_limit_type->GetSingleWordInOperand(0) != 32) {
    // The type of the loop limit constant must be 32-bit integer. It
    // doesn't actually matter whether the integer is signed or not.
    return false;
  }

  // Find the id of the "unsigned int" type.
  opt::analysis::Integer unsigned_int_type(32, false);
  uint32_t unsigned_int_type_id =
      context->get_type_mgr()->GetId(&unsigned_int_type);
  if (!unsigned_int_type_id) {
    // Unsigned int is not available; we need this type in order to add loop
    // limiters.
    return false;
  }
  // The constant lookups below need the type manager's own (registered)
  // instance of the type, not the local |unsigned_int_type|.
  auto registered_unsigned_int_type =
      context->get_type_mgr()->GetRegisteredType(&unsigned_int_type);

  // Look for 0 of type unsigned int.
  opt::analysis::IntConstant zero(registered_unsigned_int_type->AsInteger(),
                                  {0});
  auto registered_zero = context->get_constant_mgr()->FindConstant(&zero);
  if (!registered_zero) {
    // We need 0 in order to be able to initialize loop limiters.
    return false;
  }
  uint32_t zero_id = context->get_constant_mgr()
                         ->GetDefiningInstruction(registered_zero)
                         ->result_id();

  // Look for 1 of type unsigned int.
  opt::analysis::IntConstant one(registered_unsigned_int_type->AsInteger(),
                                 {1});
  auto registered_one = context->get_constant_mgr()->FindConstant(&one);
  if (!registered_one) {
    // We need 1 in order to be able to increment loop limiters.
    return false;
  }
  uint32_t one_id = context->get_constant_mgr()
                        ->GetDefiningInstruction(registered_one)
                        ->result_id();

  // Look for pointer-to-unsigned int type.
  opt::analysis::Pointer pointer_to_unsigned_int_type(
      registered_unsigned_int_type, SpvStorageClassFunction);
  uint32_t pointer_to_unsigned_int_type_id =
      context->get_type_mgr()->GetId(&pointer_to_unsigned_int_type);
  if (!pointer_to_unsigned_int_type_id) {
    // We need pointer-to-unsigned int in order to declare the loop limiter
    // variable.
    return false;
  }

  // Look for bool type.
  opt::analysis::Bool bool_type;
  uint32_t bool_type_id = context->get_type_mgr()->GetId(&bool_type);
  if (!bool_type_id) {
    // We need bool in order to compare the loop limiter's value with the loop
    // limit constant.
    return false;
  }

  // Declare the loop limiter variable at the start of the function's entry
  // block, via an instruction of the form:
  //   %loop_limiter_var = SpvOpVariable %ptr_to_uint Function %zero
  added_function->begin()->begin()->InsertBefore(MakeUnique<opt::Instruction>(
      context, SpvOpVariable, pointer_to_unsigned_int_type_id,
      message_.loop_limiter_variable_id(),
      opt::Instruction::OperandList(
          {{SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
           {SPV_OPERAND_TYPE_ID, {zero_id}}})));
  // Update the module's id bound since we have added the loop limiter
  // variable id.
  fuzzerutil::UpdateModuleIdBound(context, message_.loop_limiter_variable_id());

  // Consider each loop in turn.
  for (auto loop_header : loop_headers) {
    // Look for the loop's back-edge block. This is a predecessor of the loop
    // header that is dominated by the loop header.
    uint32_t back_edge_block_id = 0;
    for (auto pred : context->cfg()->preds(loop_header->id())) {
      if (context->GetDominatorAnalysis(added_function)
              ->Dominates(loop_header->id(), pred)) {
        back_edge_block_id = pred;
        break;
      }
    }
    if (!back_edge_block_id) {
      // The loop's back-edge block must be unreachable. This means that the
      // loop cannot iterate, so there is no need to make it livesafe; we can
      // move on from this loop.
      continue;
    }
    auto back_edge_block = context->cfg()->block(back_edge_block_id);

    // Go through the sequence of loop limiter infos and find the one
    // corresponding to this loop.
    bool found = false;
    protobufs::LoopLimiterInfo loop_limiter_info;
    for (auto& info : message_.loop_limiter_info()) {
      if (info.loop_header_id() == loop_header->id()) {
        loop_limiter_info = info;
        found = true;
        break;
      }
    }
    if (!found) {
      // We don't have loop limiter info for this loop header.
      return false;
    }

    // The back-edge block either has the form:
    //
    // (1)
    //
    // %l = OpLabel
    //      ... instructions ...
    //      OpBranch %loop_header
    //
    // (2)
    //
    // %l = OpLabel
    //      ... instructions ...
    //      OpBranchConditional %c %loop_header %loop_merge
    //
    // (3)
    //
    // %l = OpLabel
    //      ... instructions ...
    //      OpBranchConditional %c %loop_merge %loop_header
    //
    // We turn these into the following:
    //
    // (1)
    //
    //  %l = OpLabel
    //       ... instructions ...
    // %t1 = OpLoad %uint32 %loop_limiter
    // %t2 = OpIAdd %uint32 %t1 %one
    //       OpStore %loop_limiter %t2
    // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
    //       OpBranchConditional %t3 %loop_merge %loop_header
    //
    // (2)
    //
    //  %l = OpLabel
    //       ... instructions ...
    // %t1 = OpLoad %uint32 %loop_limiter
    // %t2 = OpIAdd %uint32 %t1 %one
    //       OpStore %loop_limiter %t2
    // %t3 = OpULessThan %bool %t1 %loop_limit
    // %t4 = OpLogicalAnd %bool %c %t3
    //       OpBranchConditional %t4 %loop_header %loop_merge
    //
    // (3)
    //
    //  %l = OpLabel
    //       ... instructions ...
    // %t1 = OpLoad %uint32 %loop_limiter
    // %t2 = OpIAdd %uint32 %t1 %one
    //       OpStore %loop_limiter %t2
    // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
    // %t4 = OpLogicalOr %bool %c %t3
    //       OpBranchConditional %t4 %loop_merge %loop_header

    auto back_edge_block_terminator = back_edge_block->terminator();
    // True for forms (1) and (3), where the "true" target is (or becomes) the
    // merge block, so the limit check must evaluate to true once the limit is
    // reached. False for form (2), where "true" continues the loop.
    bool compare_using_greater_than_equal;
    if (back_edge_block_terminator->opcode() == SpvOpBranch) {
      compare_using_greater_than_equal = true;
    } else {
      assert(back_edge_block_terminator->opcode() == SpvOpBranchConditional);
      assert(((back_edge_block_terminator->GetSingleWordInOperand(1) ==
                   loop_header->id() &&
               back_edge_block_terminator->GetSingleWordInOperand(2) ==
                   loop_header->MergeBlockId()) ||
              (back_edge_block_terminator->GetSingleWordInOperand(2) ==
                   loop_header->id() &&
               back_edge_block_terminator->GetSingleWordInOperand(1) ==
                   loop_header->MergeBlockId())) &&
             "A back edge edge block must branch to"
             " either the loop header or merge");
      compare_using_greater_than_equal =
          back_edge_block_terminator->GetSingleWordInOperand(1) ==
          loop_header->MergeBlockId();
    }

    std::vector<std::unique_ptr<opt::Instruction>> new_instructions;

    // Add a load from the loop limiter variable, of the form:
    //   %t1 = OpLoad %uint32 %loop_limiter
    new_instructions.push_back(MakeUnique<opt::Instruction>(
        context, SpvOpLoad, unsigned_int_type_id, loop_limiter_info.load_id(),
        opt::Instruction::OperandList(
            {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}}})));

    // Increment the loaded value:
    //   %t2 = OpIAdd %uint32 %t1 %one
    new_instructions.push_back(MakeUnique<opt::Instruction>(
        context, SpvOpIAdd, unsigned_int_type_id,
        loop_limiter_info.increment_id(),
        opt::Instruction::OperandList(
            {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.load_id()}},
             {SPV_OPERAND_TYPE_ID, {one_id}}})));

    // Store the incremented value back to the loop limiter variable:
    //   OpStore %loop_limiter %t2
    new_instructions.push_back(MakeUnique<opt::Instruction>(
        context, SpvOpStore, 0, 0,
        opt::Instruction::OperandList(
            {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}},
             {SPV_OPERAND_TYPE_ID, {loop_limiter_info.increment_id()}}})));

    // Compare the loaded (pre-increment) value with the loop limit; either:
    //   %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
    // or
    //   %t3 = OpULessThan %bool %t1 %loop_limit
    new_instructions.push_back(MakeUnique<opt::Instruction>(
        context,
        compare_using_greater_than_equal ? SpvOpUGreaterThanEqual
                                         : SpvOpULessThan,
        bool_type_id, loop_limiter_info.compare_id(),
        opt::Instruction::OperandList(
            {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.load_id()}},
             {SPV_OPERAND_TYPE_ID, {message_.loop_limit_constant_id()}}})));

    // For an existing conditional back edge, combine the original condition
    // (in-operand 0) with the limit check: OR when "true" exits, AND when
    // "true" continues.
    if (back_edge_block_terminator->opcode() == SpvOpBranchConditional) {
      new_instructions.push_back(MakeUnique<opt::Instruction>(
          context,
          compare_using_greater_than_equal ? SpvOpLogicalOr : SpvOpLogicalAnd,
          bool_type_id, loop_limiter_info.logical_op_id(),
          opt::Instruction::OperandList(
              {{SPV_OPERAND_TYPE_ID,
                {back_edge_block_terminator->GetSingleWordInOperand(0)}},
               {SPV_OPERAND_TYPE_ID, {loop_limiter_info.compare_id()}}})));
    }

    // Add the new instructions at the end of the back edge block, before the
    // terminator and any loop merge instruction (as the back edge block can
    // be the loop header).
    if (back_edge_block->GetLoopMergeInst()) {
      back_edge_block->GetLoopMergeInst()->InsertBefore(
          std::move(new_instructions));
    } else {
      back_edge_block_terminator->InsertBefore(std::move(new_instructions));
    }

    if (back_edge_block_terminator->opcode() == SpvOpBranchConditional) {
      // Forms (2) and (3): branch on the combined condition instead.
      back_edge_block_terminator->SetInOperand(
          0, {loop_limiter_info.logical_op_id()});
    } else {
      assert(back_edge_block_terminator->opcode() == SpvOpBranch &&
             "Back-edge terminator must be OpBranch or OpBranchConditional");

      // Form (1) gains a new edge to the merge block. Check that, if the
      // merge block starts with OpPhi instructions, suitable ids have been
      // provided to give these instructions a value corresponding to the new
      // incoming edge from the back edge block.
      auto merge_block = context->cfg()->block(loop_header->MergeBlockId());
      if (!fuzzerutil::PhiIdsOkForNewEdge(context, back_edge_block, merge_block,
                                          loop_limiter_info.phi_id())) {
        return false;
      }

      // Augment OpPhi instructions at the loop merge with the given ids,
      // adding one (value, predecessor) operand pair per OpPhi.
      uint32_t phi_index = 0;
      for (auto& inst : *merge_block) {
        if (inst.opcode() != SpvOpPhi) {
          break;
        }
        assert(phi_index <
                   static_cast<uint32_t>(loop_limiter_info.phi_id().size()) &&
               "There should be at least one phi id per OpPhi instruction.");
        inst.AddOperand(
            {SPV_OPERAND_TYPE_ID, {loop_limiter_info.phi_id(phi_index)}});
        inst.AddOperand({SPV_OPERAND_TYPE_ID, {back_edge_block_id}});
        phi_index++;
      }

      // Add the new edge, by changing OpBranch to OpBranchConditional.
      // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3162): This
      //  could be a problem if the merge block was originally unreachable: it
      //  might now be dominated by other blocks that it appears earlier than in
      //  the module.
      back_edge_block_terminator->SetOpcode(SpvOpBranchConditional);
      back_edge_block_terminator->SetInOperands(opt::Instruction::OperandList(
          {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.compare_id()}},
           {SPV_OPERAND_TYPE_ID, {loop_header->MergeBlockId()}},
           {SPV_OPERAND_TYPE_ID, {loop_header->id()}}}));
    }

    // Update the module's id bound with respect to the various ids that
    // have been used for loop limiter manipulation. Note that logical_op_id
    // is passed here even for form (1), where no logical op was emitted.
    fuzzerutil::UpdateModuleIdBound(context, loop_limiter_info.load_id());
    fuzzerutil::UpdateModuleIdBound(context, loop_limiter_info.increment_id());
    fuzzerutil::UpdateModuleIdBound(context, loop_limiter_info.compare_id());
    fuzzerutil::UpdateModuleIdBound(context, loop_limiter_info.logical_op_id());
  }
  return true;
}
691
TryToTurnKillOrUnreachableIntoReturn(opt::IRContext * context,opt::Function * added_function,opt::Instruction * kill_or_unreachable_inst) const692 bool TransformationAddFunction::TryToTurnKillOrUnreachableIntoReturn(
693 opt::IRContext* context, opt::Function* added_function,
694 opt::Instruction* kill_or_unreachable_inst) const {
695 assert((kill_or_unreachable_inst->opcode() == SpvOpKill ||
696 kill_or_unreachable_inst->opcode() == SpvOpUnreachable) &&
697 "Precondition: instruction must be OpKill or OpUnreachable.");
698
699 // Get the function's return type.
700 auto function_return_type_inst =
701 context->get_def_use_mgr()->GetDef(added_function->type_id());
702
703 if (function_return_type_inst->opcode() == SpvOpTypeVoid) {
704 // The function has void return type, so change this instruction to
705 // OpReturn.
706 kill_or_unreachable_inst->SetOpcode(SpvOpReturn);
707 } else {
708 // The function has non-void return type, so change this instruction
709 // to OpReturnValue, using the value id provided with the
710 // transformation.
711
712 // We first check that the id, %id, provided with the transformation
713 // specifically to turn OpKill and OpUnreachable instructions into
714 // OpReturnValue %id has the same type as the function's return type.
715 if (context->get_def_use_mgr()
716 ->GetDef(message_.kill_unreachable_return_value_id())
717 ->type_id() != function_return_type_inst->result_id()) {
718 return false;
719 }
720 kill_or_unreachable_inst->SetOpcode(SpvOpReturnValue);
721 kill_or_unreachable_inst->SetInOperands(
722 {{SPV_OPERAND_TYPE_ID, {message_.kill_unreachable_return_value_id()}}});
723 }
724 return true;
725 }
726
TryToClampAccessChainIndices(opt::IRContext * context,opt::Instruction * access_chain_inst) const727 bool TransformationAddFunction::TryToClampAccessChainIndices(
728 opt::IRContext* context, opt::Instruction* access_chain_inst) const {
729 assert((access_chain_inst->opcode() == SpvOpAccessChain ||
730 access_chain_inst->opcode() == SpvOpInBoundsAccessChain) &&
731 "Precondition: instruction must be OpAccessChain or "
732 "OpInBoundsAccessChain.");
733
734 // Find the AccessChainClampingInfo associated with this access chain.
735 const protobufs::AccessChainClampingInfo* access_chain_clamping_info =
736 nullptr;
737 for (auto& clamping_info : message_.access_chain_clamping_info()) {
738 if (clamping_info.access_chain_id() == access_chain_inst->result_id()) {
739 access_chain_clamping_info = &clamping_info;
740 break;
741 }
742 }
743 if (!access_chain_clamping_info) {
744 // No access chain clamping information was found; the function cannot be
745 // made livesafe.
746 return false;
747 }
748
749 // Check that there is a (compare_id, select_id) pair for every
750 // index associated with the instruction.
751 if (static_cast<uint32_t>(
752 access_chain_clamping_info->compare_and_select_ids().size()) !=
753 access_chain_inst->NumInOperands() - 1) {
754 return false;
755 }
756
757 // Walk the access chain, clamping each index to be within bounds if it is
758 // not a constant.
759 auto base_object = context->get_def_use_mgr()->GetDef(
760 access_chain_inst->GetSingleWordInOperand(0));
761 assert(base_object && "The base object must exist.");
762 auto pointer_type =
763 context->get_def_use_mgr()->GetDef(base_object->type_id());
764 assert(pointer_type && pointer_type->opcode() == SpvOpTypePointer &&
765 "The base object must have pointer type.");
766 auto should_be_composite_type = context->get_def_use_mgr()->GetDef(
767 pointer_type->GetSingleWordInOperand(1));
768
769 // Consider each index input operand in turn (operand 0 is the base object).
770 for (uint32_t index = 1; index < access_chain_inst->NumInOperands();
771 index++) {
772 // We are going to turn:
773 //
774 // %result = OpAccessChain %type %object ... %index ...
775 //
776 // into:
777 //
778 // %t1 = OpULessThanEqual %bool %index %bound_minus_one
779 // %t2 = OpSelect %int_type %t1 %index %bound_minus_one
780 // %result = OpAccessChain %type %object ... %t2 ...
781 //
782 // ... unless %index is already a constant.
783
784 // Get the bound for the composite being indexed into; e.g. the number of
785 // columns of matrix or the size of an array.
786 uint32_t bound =
787 GetBoundForCompositeIndex(context, *should_be_composite_type);
788
789 // Get the instruction associated with the index and figure out its integer
790 // type.
791 const uint32_t index_id = access_chain_inst->GetSingleWordInOperand(index);
792 auto index_inst = context->get_def_use_mgr()->GetDef(index_id);
793 auto index_type_inst =
794 context->get_def_use_mgr()->GetDef(index_inst->type_id());
795 assert(index_type_inst->opcode() == SpvOpTypeInt);
796 assert(index_type_inst->GetSingleWordInOperand(0) == 32);
797 opt::analysis::Integer* index_int_type =
798 context->get_type_mgr()
799 ->GetType(index_type_inst->result_id())
800 ->AsInteger();
801
802 if (index_inst->opcode() != SpvOpConstant) {
803 // The index is non-constant so we need to clamp it.
804 assert(should_be_composite_type->opcode() != SpvOpTypeStruct &&
805 "Access chain indices into structures are required to be "
806 "constants.");
807 opt::analysis::IntConstant bound_minus_one(index_int_type, {bound - 1});
808 if (!context->get_constant_mgr()->FindConstant(&bound_minus_one)) {
809 // We do not have an integer constant whose value is |bound| -1.
810 return false;
811 }
812
813 opt::analysis::Bool bool_type;
814 uint32_t bool_type_id = context->get_type_mgr()->GetId(&bool_type);
815 if (!bool_type_id) {
816 // Bool type is not declared; we cannot do a comparison.
817 return false;
818 }
819
820 uint32_t bound_minus_one_id =
821 context->get_constant_mgr()
822 ->GetDefiningInstruction(&bound_minus_one)
823 ->result_id();
824
825 uint32_t compare_id =
826 access_chain_clamping_info->compare_and_select_ids(index - 1).first();
827 uint32_t select_id =
828 access_chain_clamping_info->compare_and_select_ids(index - 1)
829 .second();
830 std::vector<std::unique_ptr<opt::Instruction>> new_instructions;
831
832 // Compare the index with the bound via an instruction of the form:
833 // %t1 = OpULessThanEqual %bool %index %bound_minus_one
834 new_instructions.push_back(MakeUnique<opt::Instruction>(
835 context, SpvOpULessThanEqual, bool_type_id, compare_id,
836 opt::Instruction::OperandList(
837 {{SPV_OPERAND_TYPE_ID, {index_inst->result_id()}},
838 {SPV_OPERAND_TYPE_ID, {bound_minus_one_id}}})));
839
840 // Select the index if in-bounds, otherwise one less than the bound:
841 // %t2 = OpSelect %int_type %t1 %index %bound_minus_one
842 new_instructions.push_back(MakeUnique<opt::Instruction>(
843 context, SpvOpSelect, index_type_inst->result_id(), select_id,
844 opt::Instruction::OperandList(
845 {{SPV_OPERAND_TYPE_ID, {compare_id}},
846 {SPV_OPERAND_TYPE_ID, {index_inst->result_id()}},
847 {SPV_OPERAND_TYPE_ID, {bound_minus_one_id}}})));
848
849 // Add the new instructions before the access chain
850 access_chain_inst->InsertBefore(std::move(new_instructions));
851
852 // Replace %index with %t2.
853 access_chain_inst->SetInOperand(index, {select_id});
854 fuzzerutil::UpdateModuleIdBound(context, compare_id);
855 fuzzerutil::UpdateModuleIdBound(context, select_id);
856 } else {
857 // TODO(afd): At present the SPIR-V spec is not clear on whether
858 // statically out-of-bounds indices mean that a module is invalid (so
859 // that it should be rejected by the validator), or that such accesses
860 // yield undefined results. Via the following assertion, we assume that
861 // functions added to the module do not feature statically out-of-bounds
862 // accesses.
863 // Assert that the index is smaller (unsigned) than this value.
864 // Return false if it is not (to keep compilers happy).
865 if (index_inst->GetSingleWordInOperand(0) >= bound) {
866 assert(false &&
867 "The function has a statically out-of-bounds access; "
868 "this should not occur.");
869 return false;
870 }
871 }
872 should_be_composite_type =
873 FollowCompositeIndex(context, *should_be_composite_type, index_id);
874 }
875 return true;
876 }
877
GetBoundForCompositeIndex(opt::IRContext * context,const opt::Instruction & composite_type_inst)878 uint32_t TransformationAddFunction::GetBoundForCompositeIndex(
879 opt::IRContext* context, const opt::Instruction& composite_type_inst) {
880 switch (composite_type_inst.opcode()) {
881 case SpvOpTypeArray:
882 return fuzzerutil::GetArraySize(composite_type_inst, context);
883 case SpvOpTypeMatrix:
884 case SpvOpTypeVector:
885 return composite_type_inst.GetSingleWordInOperand(1);
886 case SpvOpTypeStruct: {
887 return fuzzerutil::GetNumberOfStructMembers(composite_type_inst);
888 }
889 default:
890 assert(false && "Unknown composite type.");
891 return 0;
892 }
893 }
894
FollowCompositeIndex(opt::IRContext * context,const opt::Instruction & composite_type_inst,uint32_t index_id)895 opt::Instruction* TransformationAddFunction::FollowCompositeIndex(
896 opt::IRContext* context, const opt::Instruction& composite_type_inst,
897 uint32_t index_id) {
898 uint32_t sub_object_type_id;
899 switch (composite_type_inst.opcode()) {
900 case SpvOpTypeArray:
901 sub_object_type_id = composite_type_inst.GetSingleWordInOperand(0);
902 break;
903 case SpvOpTypeMatrix:
904 case SpvOpTypeVector:
905 sub_object_type_id = composite_type_inst.GetSingleWordInOperand(0);
906 break;
907 case SpvOpTypeStruct: {
908 auto index_inst = context->get_def_use_mgr()->GetDef(index_id);
909 assert(index_inst->opcode() == SpvOpConstant);
910 assert(
911 context->get_def_use_mgr()->GetDef(index_inst->type_id())->opcode() ==
912 SpvOpTypeInt);
913 assert(context->get_def_use_mgr()
914 ->GetDef(index_inst->type_id())
915 ->GetSingleWordInOperand(0) == 32);
916 uint32_t index_value = index_inst->GetSingleWordInOperand(0);
917 sub_object_type_id =
918 composite_type_inst.GetSingleWordInOperand(index_value);
919 break;
920 }
921 default:
922 assert(false && "Unknown composite type.");
923 sub_object_type_id = 0;
924 break;
925 }
926 assert(sub_object_type_id && "No sub-object found.");
927 return context->get_def_use_mgr()->GetDef(sub_object_type_id);
928 }
929
930 } // namespace fuzz
931 } // namespace spvtools
932