1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/fuzz/transformation_add_function.h"
16
17 #include "source/fuzz/fuzzer_util.h"
18 #include "source/fuzz/instruction_message.h"
19
20 namespace spvtools {
21 namespace fuzz {
22
TransformationAddFunction(protobufs::TransformationAddFunction message)23 TransformationAddFunction::TransformationAddFunction(
24 protobufs::TransformationAddFunction message)
25 : message_(std::move(message)) {}
26
TransformationAddFunction(const std::vector<protobufs::Instruction> & instructions)27 TransformationAddFunction::TransformationAddFunction(
28 const std::vector<protobufs::Instruction>& instructions) {
29 for (auto& instruction : instructions) {
30 *message_.add_instruction() = instruction;
31 }
32 message_.set_is_livesafe(false);
33 }
34
TransformationAddFunction(const std::vector<protobufs::Instruction> & instructions,uint32_t loop_limiter_variable_id,uint32_t loop_limit_constant_id,const std::vector<protobufs::LoopLimiterInfo> & loop_limiters,uint32_t kill_unreachable_return_value_id,const std::vector<protobufs::AccessChainClampingInfo> & access_chain_clampers)35 TransformationAddFunction::TransformationAddFunction(
36 const std::vector<protobufs::Instruction>& instructions,
37 uint32_t loop_limiter_variable_id, uint32_t loop_limit_constant_id,
38 const std::vector<protobufs::LoopLimiterInfo>& loop_limiters,
39 uint32_t kill_unreachable_return_value_id,
40 const std::vector<protobufs::AccessChainClampingInfo>&
41 access_chain_clampers) {
42 for (auto& instruction : instructions) {
43 *message_.add_instruction() = instruction;
44 }
45 message_.set_is_livesafe(true);
46 message_.set_loop_limiter_variable_id(loop_limiter_variable_id);
47 message_.set_loop_limit_constant_id(loop_limit_constant_id);
48 for (auto& loop_limiter : loop_limiters) {
49 *message_.add_loop_limiter_info() = loop_limiter;
50 }
51 message_.set_kill_unreachable_return_value_id(
52 kill_unreachable_return_value_id);
53 for (auto& access_clamper : access_chain_clampers) {
54 *message_.add_access_chain_clamping_info() = access_clamper;
55 }
56 }
57
IsApplicable(opt::IRContext * ir_context,const TransformationContext & transformation_context) const58 bool TransformationAddFunction::IsApplicable(
59 opt::IRContext* ir_context,
60 const TransformationContext& transformation_context) const {
61 // This transformation may use a lot of ids, all of which need to be fresh
62 // and distinct. This set tracks them.
63 std::set<uint32_t> ids_used_by_this_transformation;
64
65 // Ensure that all result ids in the new function are fresh and distinct.
66 for (auto& instruction : message_.instruction()) {
67 if (instruction.result_id()) {
68 if (!CheckIdIsFreshAndNotUsedByThisTransformation(
69 instruction.result_id(), ir_context,
70 &ids_used_by_this_transformation)) {
71 return false;
72 }
73 }
74 }
75
76 if (message_.is_livesafe()) {
77 // Ensure that all ids provided for making the function livesafe are fresh
78 // and distinct.
79 if (!CheckIdIsFreshAndNotUsedByThisTransformation(
80 message_.loop_limiter_variable_id(), ir_context,
81 &ids_used_by_this_transformation)) {
82 return false;
83 }
84 for (auto& loop_limiter_info : message_.loop_limiter_info()) {
85 if (!CheckIdIsFreshAndNotUsedByThisTransformation(
86 loop_limiter_info.load_id(), ir_context,
87 &ids_used_by_this_transformation)) {
88 return false;
89 }
90 if (!CheckIdIsFreshAndNotUsedByThisTransformation(
91 loop_limiter_info.increment_id(), ir_context,
92 &ids_used_by_this_transformation)) {
93 return false;
94 }
95 if (!CheckIdIsFreshAndNotUsedByThisTransformation(
96 loop_limiter_info.compare_id(), ir_context,
97 &ids_used_by_this_transformation)) {
98 return false;
99 }
100 if (!CheckIdIsFreshAndNotUsedByThisTransformation(
101 loop_limiter_info.logical_op_id(), ir_context,
102 &ids_used_by_this_transformation)) {
103 return false;
104 }
105 }
106 for (auto& access_chain_clamping_info :
107 message_.access_chain_clamping_info()) {
108 for (auto& pair : access_chain_clamping_info.compare_and_select_ids()) {
109 if (!CheckIdIsFreshAndNotUsedByThisTransformation(
110 pair.first(), ir_context, &ids_used_by_this_transformation)) {
111 return false;
112 }
113 if (!CheckIdIsFreshAndNotUsedByThisTransformation(
114 pair.second(), ir_context, &ids_used_by_this_transformation)) {
115 return false;
116 }
117 }
118 }
119 }
120
121 // Because checking all the conditions for a function to be valid is a big
122 // job that the SPIR-V validator can already do, a "try it and see" approach
123 // is taken here.
124
125 // We first clone the current module, so that we can try adding the new
126 // function without risking wrecking |ir_context|.
127 auto cloned_module = fuzzerutil::CloneIRContext(ir_context);
128
129 // We try to add a function to the cloned module, which may fail if
130 // |message_.instruction| is not sufficiently well-formed.
131 if (!TryToAddFunction(cloned_module.get())) {
132 return false;
133 }
134
135 // Check whether the cloned module is still valid after adding the function.
136 // If it is not, the transformation is not applicable.
137 if (!fuzzerutil::IsValid(cloned_module.get(),
138 transformation_context.GetValidatorOptions(),
139 fuzzerutil::kSilentMessageConsumer)) {
140 return false;
141 }
142
143 if (message_.is_livesafe()) {
144 if (!TryToMakeFunctionLivesafe(cloned_module.get(),
145 transformation_context)) {
146 return false;
147 }
148 // After making the function livesafe, we check validity of the module
149 // again. This is because the turning of OpKill, OpUnreachable and OpReturn
150 // instructions into branches changes control flow graph reachability, which
151 // has the potential to make the module invalid when it was otherwise valid.
152 // It is simpler to rely on the validator to guard against this than to
153 // consider all scenarios when making a function livesafe.
154 if (!fuzzerutil::IsValid(cloned_module.get(),
155 transformation_context.GetValidatorOptions(),
156 fuzzerutil::kSilentMessageConsumer)) {
157 return false;
158 }
159 }
160 return true;
161 }
162
Apply(opt::IRContext * ir_context,TransformationContext * transformation_context) const163 void TransformationAddFunction::Apply(
164 opt::IRContext* ir_context,
165 TransformationContext* transformation_context) const {
166 // Add the function to the module. As the transformation is applicable, this
167 // should succeed.
168 bool success = TryToAddFunction(ir_context);
169 assert(success && "The function should be successfully added.");
170 (void)(success); // Keep release builds happy (otherwise they may complain
171 // that |success| is not used).
172
173 if (message_.is_livesafe()) {
174 // Make the function livesafe, which also should succeed.
175 success = TryToMakeFunctionLivesafe(ir_context, *transformation_context);
176 assert(success && "It should be possible to make the function livesafe.");
177 (void)(success); // Keep release builds happy.
178 }
179 ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
180
181 assert(message_.instruction(0).opcode() == SpvOpFunction &&
182 "The first instruction of an 'add function' transformation must be "
183 "OpFunction.");
184
185 if (message_.is_livesafe()) {
186 // Inform the fact manager that the function is livesafe.
187 transformation_context->GetFactManager()->AddFactFunctionIsLivesafe(
188 message_.instruction(0).result_id());
189 } else {
190 // Inform the fact manager that all blocks in the function are dead.
191 for (auto& inst : message_.instruction()) {
192 if (inst.opcode() == SpvOpLabel) {
193 transformation_context->GetFactManager()->AddFactBlockIsDead(
194 inst.result_id());
195 }
196 }
197 }
198
199 // Record the fact that all pointer parameters and variables declared in the
200 // function should be regarded as having irrelevant values. This allows other
201 // passes to store arbitrarily to such variables, and to pass them freely as
202 // parameters to other functions knowing that it is OK if they get
203 // over-written.
204 for (auto& instruction : message_.instruction()) {
205 switch (instruction.opcode()) {
206 case SpvOpFunctionParameter:
207 if (ir_context->get_def_use_mgr()
208 ->GetDef(instruction.result_type_id())
209 ->opcode() == SpvOpTypePointer) {
210 transformation_context->GetFactManager()
211 ->AddFactValueOfPointeeIsIrrelevant(instruction.result_id());
212 }
213 break;
214 case SpvOpVariable:
215 transformation_context->GetFactManager()
216 ->AddFactValueOfPointeeIsIrrelevant(instruction.result_id());
217 break;
218 default:
219 break;
220 }
221 }
222 }
223
ToMessage() const224 protobufs::Transformation TransformationAddFunction::ToMessage() const {
225 protobufs::Transformation result;
226 *result.mutable_add_function() = message_;
227 return result;
228 }
229
TryToAddFunction(opt::IRContext * ir_context) const230 bool TransformationAddFunction::TryToAddFunction(
231 opt::IRContext* ir_context) const {
232 // This function returns false if |message_.instruction| was not well-formed
233 // enough to actually create a function and add it to |ir_context|.
234
235 // A function must have at least some instructions.
236 if (message_.instruction().empty()) {
237 return false;
238 }
239
240 // A function must start with OpFunction.
241 auto function_begin = message_.instruction(0);
242 if (function_begin.opcode() != SpvOpFunction) {
243 return false;
244 }
245
246 // Make a function, headed by the OpFunction instruction.
247 std::unique_ptr<opt::Function> new_function = MakeUnique<opt::Function>(
248 InstructionFromMessage(ir_context, function_begin));
249
250 // Keeps track of which instruction protobuf message we are currently
251 // considering.
252 uint32_t instruction_index = 1;
253 const auto num_instructions =
254 static_cast<uint32_t>(message_.instruction().size());
255
256 // Iterate through all function parameter instructions, adding parameters to
257 // the new function.
258 while (instruction_index < num_instructions &&
259 message_.instruction(instruction_index).opcode() ==
260 SpvOpFunctionParameter) {
261 new_function->AddParameter(InstructionFromMessage(
262 ir_context, message_.instruction(instruction_index)));
263 instruction_index++;
264 }
265
266 // After the parameters, there needs to be a label.
267 if (instruction_index == num_instructions ||
268 message_.instruction(instruction_index).opcode() != SpvOpLabel) {
269 return false;
270 }
271
272 // Iterate through the instructions block by block until the end of the
273 // function is reached.
274 while (instruction_index < num_instructions &&
275 message_.instruction(instruction_index).opcode() != SpvOpFunctionEnd) {
276 // Invariant: we should always be at a label instruction at this point.
277 assert(message_.instruction(instruction_index).opcode() == SpvOpLabel);
278
279 // Make a basic block using the label instruction.
280 std::unique_ptr<opt::BasicBlock> block =
281 MakeUnique<opt::BasicBlock>(InstructionFromMessage(
282 ir_context, message_.instruction(instruction_index)));
283
284 // Consider successive instructions until we hit another label or the end
285 // of the function, adding each such instruction to the block.
286 instruction_index++;
287 while (instruction_index < num_instructions &&
288 message_.instruction(instruction_index).opcode() !=
289 SpvOpFunctionEnd &&
290 message_.instruction(instruction_index).opcode() != SpvOpLabel) {
291 block->AddInstruction(InstructionFromMessage(
292 ir_context, message_.instruction(instruction_index)));
293 instruction_index++;
294 }
295 // Add the block to the new function.
296 new_function->AddBasicBlock(std::move(block));
297 }
298 // Having considered all the blocks, we should be at the last instruction and
299 // it needs to be OpFunctionEnd.
300 if (instruction_index != num_instructions - 1 ||
301 message_.instruction(instruction_index).opcode() != SpvOpFunctionEnd) {
302 return false;
303 }
304 // Set the function's final instruction, add the function to the module and
305 // report success.
306 new_function->SetFunctionEnd(InstructionFromMessage(
307 ir_context, message_.instruction(instruction_index)));
308 ir_context->AddFunction(std::move(new_function));
309
310 ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
311
312 return true;
313 }
314
TryToMakeFunctionLivesafe(opt::IRContext * ir_context,const TransformationContext & transformation_context) const315 bool TransformationAddFunction::TryToMakeFunctionLivesafe(
316 opt::IRContext* ir_context,
317 const TransformationContext& transformation_context) const {
318 assert(message_.is_livesafe() && "Precondition: is_livesafe must hold.");
319
320 // Get a pointer to the added function.
321 opt::Function* added_function = nullptr;
322 for (auto& function : *ir_context->module()) {
323 if (function.result_id() == message_.instruction(0).result_id()) {
324 added_function = &function;
325 break;
326 }
327 }
328 assert(added_function && "The added function should have been found.");
329
330 if (!TryToAddLoopLimiters(ir_context, added_function)) {
331 // Adding loop limiters did not work; bail out.
332 return false;
333 }
334
335 // Consider all the instructions in the function, and:
336 // - attempt to replace OpKill and OpUnreachable with return instructions
337 // - attempt to clamp access chains to be within bounds
338 // - check that OpFunctionCall instructions are only to livesafe functions
339 for (auto& block : *added_function) {
340 for (auto& inst : block) {
341 switch (inst.opcode()) {
342 case SpvOpKill:
343 case SpvOpUnreachable:
344 if (!TryToTurnKillOrUnreachableIntoReturn(ir_context, added_function,
345 &inst)) {
346 return false;
347 }
348 break;
349 case SpvOpAccessChain:
350 case SpvOpInBoundsAccessChain:
351 if (!TryToClampAccessChainIndices(ir_context, &inst)) {
352 return false;
353 }
354 break;
355 case SpvOpFunctionCall:
356 // A livesafe function my only call other livesafe functions.
357 if (!transformation_context.GetFactManager()->FunctionIsLivesafe(
358 inst.GetSingleWordInOperand(0))) {
359 return false;
360 }
361 default:
362 break;
363 }
364 }
365 }
366 return true;
367 }
368
GetBackEdgeBlockId(opt::IRContext * ir_context,uint32_t loop_header_block_id)369 uint32_t TransformationAddFunction::GetBackEdgeBlockId(
370 opt::IRContext* ir_context, uint32_t loop_header_block_id) {
371 const auto* loop_header_block =
372 ir_context->cfg()->block(loop_header_block_id);
373 assert(loop_header_block && "|loop_header_block_id| is invalid");
374
375 for (auto pred : ir_context->cfg()->preds(loop_header_block_id)) {
376 if (ir_context->GetDominatorAnalysis(loop_header_block->GetParent())
377 ->Dominates(loop_header_block_id, pred)) {
378 return pred;
379 }
380 }
381
382 return 0;
383 }
384
// Bounds the iteration count of every loop in |added_function|.
//
// One Function-storage unsigned 32-bit variable,
// |message_.loop_limiter_variable_id|, is declared at the start of the entry
// block and initialized to 0. All loops share it. Each loop's back-edge block
// loads the variable, stores back the incremented value, and compares the
// pre-increment value against |message_.loop_limit_constant_id|. The loop then
// exits to its merge block once the limit is reached.
//
// Loops whose back-edge block cannot be found are skipped.
//
// Returns false if any of the following holds:
//  - the module lacks a required ingredient: a 32-bit integer OpConstant loop
//    limit, an unsigned 32-bit int type, the constants 0 and 1 of that type, a
//    Function-storage pointer to it, or a bool type;
//  - a loop header does not dominate its merge block;
//  - no LoopLimiterInfo is provided for a loop;
//  - the merge block's OpPhi instructions cannot be given values for a newly
//    added edge.
// The limiter variable is inserted before some of these checks are made, so
// |ir_context| may already have been modified when false is returned.
bool TransformationAddFunction::TryToAddLoopLimiters(
    opt::IRContext* ir_context, opt::Function* added_function) const {
  // Collect up all the loop headers so that we can subsequently add loop
  // limiting logic.
  std::vector<opt::BasicBlock*> loop_headers;
  for (auto& block : *added_function) {
    if (block.IsLoopHeader()) {
      loop_headers.push_back(&block);
    }
  }

  if (loop_headers.empty()) {
    // There are no loops, so no need to add any loop limiters.
    return true;
  }

  // Check that the module contains appropriate ingredients for declaring and
  // manipulating a loop limiter.

  auto loop_limit_constant_id_instr =
      ir_context->get_def_use_mgr()->GetDef(message_.loop_limit_constant_id());
  if (!loop_limit_constant_id_instr ||
      loop_limit_constant_id_instr->opcode() != SpvOpConstant) {
    // The loop limit constant id instruction must exist and have an
    // appropriate opcode.
    return false;
  }

  auto loop_limit_type = ir_context->get_def_use_mgr()->GetDef(
      loop_limit_constant_id_instr->type_id());
  if (loop_limit_type->opcode() != SpvOpTypeInt ||
      loop_limit_type->GetSingleWordInOperand(0) != 32) {
    // The type of the loop limit constant must be 32-bit integer. It
    // doesn't actually matter whether the integer is signed or not.
    return false;
  }

  // Find the id of the "unsigned int" type.
  opt::analysis::Integer unsigned_int_type(32, false);
  uint32_t unsigned_int_type_id =
      ir_context->get_type_mgr()->GetId(&unsigned_int_type);
  if (!unsigned_int_type_id) {
    // Unsigned int is not available; we need this type in order to add loop
    // limiters.
    return false;
  }
  auto registered_unsigned_int_type =
      ir_context->get_type_mgr()->GetRegisteredType(&unsigned_int_type);

  // Look for 0 of type unsigned int.
  opt::analysis::IntConstant zero(registered_unsigned_int_type->AsInteger(),
                                  {0});
  auto registered_zero = ir_context->get_constant_mgr()->FindConstant(&zero);
  if (!registered_zero) {
    // We need 0 in order to be able to initialize loop limiters.
    return false;
  }
  uint32_t zero_id = ir_context->get_constant_mgr()
                         ->GetDefiningInstruction(registered_zero)
                         ->result_id();

  // Look for 1 of type unsigned int.
  opt::analysis::IntConstant one(registered_unsigned_int_type->AsInteger(),
                                 {1});
  auto registered_one = ir_context->get_constant_mgr()->FindConstant(&one);
  if (!registered_one) {
    // We need 1 in order to be able to increment loop limiters.
    return false;
  }
  uint32_t one_id = ir_context->get_constant_mgr()
                        ->GetDefiningInstruction(registered_one)
                        ->result_id();

  // Look for pointer-to-unsigned int type.
  opt::analysis::Pointer pointer_to_unsigned_int_type(
      registered_unsigned_int_type, SpvStorageClassFunction);
  uint32_t pointer_to_unsigned_int_type_id =
      ir_context->get_type_mgr()->GetId(&pointer_to_unsigned_int_type);
  if (!pointer_to_unsigned_int_type_id) {
    // We need pointer-to-unsigned int in order to declare the loop limiter
    // variable.
    return false;
  }

  // Look for bool type.
  opt::analysis::Bool bool_type;
  uint32_t bool_type_id = ir_context->get_type_mgr()->GetId(&bool_type);
  if (!bool_type_id) {
    // We need bool in order to compare the loop limiter's value with the loop
    // limit constant.
    return false;
  }

  // Declare the loop limiter variable at the start of the function's entry
  // block, via an instruction of the form:
  //   %loop_limiter_var = SpvOpVariable %ptr_to_uint Function %zero
  // From this point on, |ir_context| is modified even if false is returned
  // later.
  added_function->begin()->begin()->InsertBefore(MakeUnique<opt::Instruction>(
      ir_context, SpvOpVariable, pointer_to_unsigned_int_type_id,
      message_.loop_limiter_variable_id(),
      opt::Instruction::OperandList(
          {{SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
           {SPV_OPERAND_TYPE_ID, {zero_id}}})));
  // Update the module's id bound since we have added the loop limiter
  // variable id.
  fuzzerutil::UpdateModuleIdBound(ir_context,
                                  message_.loop_limiter_variable_id());

  // Consider each loop in turn.
  // NOTE(review): back-edge terminators are rewritten below without
  // invalidating the CFG or the dominator analysis. Later iterations of this
  // loop therefore query analyses computed before earlier rewrites. Confirm
  // this cannot affect the result, e.g. when one loop's merge block is another
  // loop's header.
  for (auto loop_header : loop_headers) {
    // Look for the loop's back-edge block. This is a predecessor of the loop
    // header that is dominated by the loop header.
    const auto back_edge_block_id =
        GetBackEdgeBlockId(ir_context, loop_header->id());
    if (!back_edge_block_id) {
      // The loop's back-edge block must be unreachable. This means that the
      // loop cannot iterate, so there is no need to make it livesafe; we can
      // move on from this loop.
      continue;
    }

    // If the loop's merge block is unreachable, then there are no constraints
    // on where the merge block appears in relation to the blocks of the loop.
    // This means we need to be careful when adding a branch from the back-edge
    // block to the merge block: the branch might make the loop merge reachable,
    // and it might then be dominated by the loop header and possibly by other
    // blocks in the loop. Since a block needs to appear before those blocks it
    // strictly dominates, this could make the module invalid. To avoid this
    // problem we bail out in the case where the loop header does not dominate
    // the loop merge.
    if (!ir_context->GetDominatorAnalysis(added_function)
             ->Dominates(loop_header->id(), loop_header->MergeBlockId())) {
      return false;
    }

    // Go through the sequence of loop limiter infos and find the one
    // corresponding to this loop.
    bool found = false;
    protobufs::LoopLimiterInfo loop_limiter_info;
    for (auto& info : message_.loop_limiter_info()) {
      if (info.loop_header_id() == loop_header->id()) {
        loop_limiter_info = info;
        found = true;
        break;
      }
    }
    if (!found) {
      // We don't have loop limiter info for this loop header.
      return false;
    }

    // The back-edge block either has the form:
    //
    // (1)
    //
    // %l = OpLabel
    //      ... instructions ...
    //      OpBranch %loop_header
    //
    // (2)
    //
    // %l = OpLabel
    //      ... instructions ...
    //      OpBranchConditional %c %loop_header %loop_merge
    //
    // (3)
    //
    // %l = OpLabel
    //      ... instructions ...
    //      OpBranchConditional %c %loop_merge %loop_header
    //
    // We turn these into the following:
    //
    // (1)
    //
    //  %l = OpLabel
    //       ... instructions ...
    // %t1 = OpLoad %uint32 %loop_limiter
    // %t2 = OpIAdd %uint32 %t1 %one
    //       OpStore %loop_limiter %t2
    // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
    //       OpBranchConditional %t3 %loop_merge %loop_header
    //
    // (2)
    //
    //  %l = OpLabel
    //       ... instructions ...
    // %t1 = OpLoad %uint32 %loop_limiter
    // %t2 = OpIAdd %uint32 %t1 %one
    //       OpStore %loop_limiter %t2
    // %t3 = OpULessThan %bool %t1 %loop_limit
    // %t4 = OpLogicalAnd %bool %c %t3
    //       OpBranchConditional %t4 %loop_header %loop_merge
    //
    // (3)
    //
    //  %l = OpLabel
    //       ... instructions ...
    // %t1 = OpLoad %uint32 %loop_limiter
    // %t2 = OpIAdd %uint32 %t1 %one
    //       OpStore %loop_limiter %t2
    // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
    // %t4 = OpLogicalOr %bool %c %t3
    //       OpBranchConditional %t4 %loop_merge %loop_header

    auto back_edge_block = ir_context->cfg()->block(back_edge_block_id);
    auto back_edge_block_terminator = back_edge_block->terminator();
    // True for forms (1) and (3), where a "true" condition means "exit the
    // loop". False for form (2), where a "true" condition means "keep looping".
    bool compare_using_greater_than_equal;
    if (back_edge_block_terminator->opcode() == SpvOpBranch) {
      compare_using_greater_than_equal = true;
    } else {
      assert(back_edge_block_terminator->opcode() == SpvOpBranchConditional);
      assert(((back_edge_block_terminator->GetSingleWordInOperand(1) ==
                   loop_header->id() &&
               back_edge_block_terminator->GetSingleWordInOperand(2) ==
                   loop_header->MergeBlockId()) ||
              (back_edge_block_terminator->GetSingleWordInOperand(2) ==
                   loop_header->id() &&
               back_edge_block_terminator->GetSingleWordInOperand(1) ==
                   loop_header->MergeBlockId())) &&
             "A back edge edge block must branch to"
             " either the loop header or merge");
      compare_using_greater_than_equal =
          back_edge_block_terminator->GetSingleWordInOperand(1) ==
          loop_header->MergeBlockId();
    }

    std::vector<std::unique_ptr<opt::Instruction>> new_instructions;

    // Add a load from the loop limiter variable, of the form:
    //   %t1 = OpLoad %uint32 %loop_limiter
    new_instructions.push_back(MakeUnique<opt::Instruction>(
        ir_context, SpvOpLoad, unsigned_int_type_id,
        loop_limiter_info.load_id(),
        opt::Instruction::OperandList(
            {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}}})));

    // Increment the loaded value:
    //   %t2 = OpIAdd %uint32 %t1 %one
    new_instructions.push_back(MakeUnique<opt::Instruction>(
        ir_context, SpvOpIAdd, unsigned_int_type_id,
        loop_limiter_info.increment_id(),
        opt::Instruction::OperandList(
            {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.load_id()}},
             {SPV_OPERAND_TYPE_ID, {one_id}}})));

    // Store the incremented value back to the loop limiter variable:
    //   OpStore %loop_limiter %t2
    new_instructions.push_back(MakeUnique<opt::Instruction>(
        ir_context, SpvOpStore, 0, 0,
        opt::Instruction::OperandList(
            {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}},
             {SPV_OPERAND_TYPE_ID, {loop_limiter_info.increment_id()}}})));

    // Compare the loaded (pre-increment) value with the loop limit; either:
    //   %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
    // or
    //   %t3 = OpULessThan %bool %t1 %loop_limit
    new_instructions.push_back(MakeUnique<opt::Instruction>(
        ir_context,
        compare_using_greater_than_equal ? SpvOpUGreaterThanEqual
                                         : SpvOpULessThan,
        bool_type_id, loop_limiter_info.compare_id(),
        opt::Instruction::OperandList(
            {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.load_id()}},
             {SPV_OPERAND_TYPE_ID, {message_.loop_limit_constant_id()}}})));

    // For a conditional back edge, combine the original condition %c with the
    // limit comparison (OpLogicalOr for form (3), OpLogicalAnd for form (2)).
    if (back_edge_block_terminator->opcode() == SpvOpBranchConditional) {
      new_instructions.push_back(MakeUnique<opt::Instruction>(
          ir_context,
          compare_using_greater_than_equal ? SpvOpLogicalOr : SpvOpLogicalAnd,
          bool_type_id, loop_limiter_info.logical_op_id(),
          opt::Instruction::OperandList(
              {{SPV_OPERAND_TYPE_ID,
                {back_edge_block_terminator->GetSingleWordInOperand(0)}},
               {SPV_OPERAND_TYPE_ID, {loop_limiter_info.compare_id()}}})));
    }

    // Add the new instructions at the end of the back edge block, before the
    // terminator and any loop merge instruction (as the back edge block can
    // be the loop header).
    if (back_edge_block->GetLoopMergeInst()) {
      back_edge_block->GetLoopMergeInst()->InsertBefore(
          std::move(new_instructions));
    } else {
      back_edge_block_terminator->InsertBefore(std::move(new_instructions));
    }

    if (back_edge_block_terminator->opcode() == SpvOpBranchConditional) {
      // Forms (2) and (3): only the branch condition changes; both edges
      // already exist.
      back_edge_block_terminator->SetInOperand(
          0, {loop_limiter_info.logical_op_id()});
    } else {
      assert(back_edge_block_terminator->opcode() == SpvOpBranch &&
             "Back-edge terminator must be OpBranch or OpBranchConditional");

      // Check that, if the merge block starts with OpPhi instructions, suitable
      // ids have been provided to give these instructions a value corresponding
      // to the new incoming edge from the back edge block.
      auto merge_block = ir_context->cfg()->block(loop_header->MergeBlockId());
      if (!fuzzerutil::PhiIdsOkForNewEdge(ir_context, back_edge_block,
                                          merge_block,
                                          loop_limiter_info.phi_id())) {
        return false;
      }

      // Augment OpPhi instructions at the loop merge with the given ids.
      uint32_t phi_index = 0;
      for (auto& inst : *merge_block) {
        if (inst.opcode() != SpvOpPhi) {
          break;
        }
        assert(phi_index <
                   static_cast<uint32_t>(loop_limiter_info.phi_id().size()) &&
               "There should be at least one phi id per OpPhi instruction.");
        inst.AddOperand(
            {SPV_OPERAND_TYPE_ID, {loop_limiter_info.phi_id(phi_index)}});
        inst.AddOperand({SPV_OPERAND_TYPE_ID, {back_edge_block_id}});
        phi_index++;
      }

      // Add the new edge, by changing OpBranch to OpBranchConditional.
      back_edge_block_terminator->SetOpcode(SpvOpBranchConditional);
      back_edge_block_terminator->SetInOperands(opt::Instruction::OperandList(
          {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.compare_id()}},
           {SPV_OPERAND_TYPE_ID, {loop_header->MergeBlockId()}},
           {SPV_OPERAND_TYPE_ID, {loop_header->id()}}}));
    }

    // Update the module's id bound with respect to the various ids that
    // have been used for loop limiter manipulation. Note that logical_op_id is
    // only emitted for an OpBranchConditional back edge, but the bound is
    // bumped for it in every case.
    fuzzerutil::UpdateModuleIdBound(ir_context, loop_limiter_info.load_id());
    fuzzerutil::UpdateModuleIdBound(ir_context,
                                    loop_limiter_info.increment_id());
    fuzzerutil::UpdateModuleIdBound(ir_context, loop_limiter_info.compare_id());
    fuzzerutil::UpdateModuleIdBound(ir_context,
                                    loop_limiter_info.logical_op_id());
  }
  return true;
}
723
TryToTurnKillOrUnreachableIntoReturn(opt::IRContext * ir_context,opt::Function * added_function,opt::Instruction * kill_or_unreachable_inst) const724 bool TransformationAddFunction::TryToTurnKillOrUnreachableIntoReturn(
725 opt::IRContext* ir_context, opt::Function* added_function,
726 opt::Instruction* kill_or_unreachable_inst) const {
727 assert((kill_or_unreachable_inst->opcode() == SpvOpKill ||
728 kill_or_unreachable_inst->opcode() == SpvOpUnreachable) &&
729 "Precondition: instruction must be OpKill or OpUnreachable.");
730
731 // Get the function's return type.
732 auto function_return_type_inst =
733 ir_context->get_def_use_mgr()->GetDef(added_function->type_id());
734
735 if (function_return_type_inst->opcode() == SpvOpTypeVoid) {
736 // The function has void return type, so change this instruction to
737 // OpReturn.
738 kill_or_unreachable_inst->SetOpcode(SpvOpReturn);
739 } else {
740 // The function has non-void return type, so change this instruction
741 // to OpReturnValue, using the value id provided with the
742 // transformation.
743
744 // We first check that the id, %id, provided with the transformation
745 // specifically to turn OpKill and OpUnreachable instructions into
746 // OpReturnValue %id has the same type as the function's return type.
747 if (ir_context->get_def_use_mgr()
748 ->GetDef(message_.kill_unreachable_return_value_id())
749 ->type_id() != function_return_type_inst->result_id()) {
750 return false;
751 }
752 kill_or_unreachable_inst->SetOpcode(SpvOpReturnValue);
753 kill_or_unreachable_inst->SetInOperands(
754 {{SPV_OPERAND_TYPE_ID, {message_.kill_unreachable_return_value_id()}}});
755 }
756 return true;
757 }
758
TryToClampAccessChainIndices(opt::IRContext * ir_context,opt::Instruction * access_chain_inst) const759 bool TransformationAddFunction::TryToClampAccessChainIndices(
760 opt::IRContext* ir_context, opt::Instruction* access_chain_inst) const {
761 assert((access_chain_inst->opcode() == SpvOpAccessChain ||
762 access_chain_inst->opcode() == SpvOpInBoundsAccessChain) &&
763 "Precondition: instruction must be OpAccessChain or "
764 "OpInBoundsAccessChain.");
765
766 // Find the AccessChainClampingInfo associated with this access chain.
767 const protobufs::AccessChainClampingInfo* access_chain_clamping_info =
768 nullptr;
769 for (auto& clamping_info : message_.access_chain_clamping_info()) {
770 if (clamping_info.access_chain_id() == access_chain_inst->result_id()) {
771 access_chain_clamping_info = &clamping_info;
772 break;
773 }
774 }
775 if (!access_chain_clamping_info) {
776 // No access chain clamping information was found; the function cannot be
777 // made livesafe.
778 return false;
779 }
780
781 // Check that there is a (compare_id, select_id) pair for every
782 // index associated with the instruction.
783 if (static_cast<uint32_t>(
784 access_chain_clamping_info->compare_and_select_ids().size()) !=
785 access_chain_inst->NumInOperands() - 1) {
786 return false;
787 }
788
789 // Walk the access chain, clamping each index to be within bounds if it is
790 // not a constant.
791 auto base_object = ir_context->get_def_use_mgr()->GetDef(
792 access_chain_inst->GetSingleWordInOperand(0));
793 assert(base_object && "The base object must exist.");
794 auto pointer_type =
795 ir_context->get_def_use_mgr()->GetDef(base_object->type_id());
796 assert(pointer_type && pointer_type->opcode() == SpvOpTypePointer &&
797 "The base object must have pointer type.");
798 auto should_be_composite_type = ir_context->get_def_use_mgr()->GetDef(
799 pointer_type->GetSingleWordInOperand(1));
800
801 // Consider each index input operand in turn (operand 0 is the base object).
802 for (uint32_t index = 1; index < access_chain_inst->NumInOperands();
803 index++) {
804 // We are going to turn:
805 //
806 // %result = OpAccessChain %type %object ... %index ...
807 //
808 // into:
809 //
810 // %t1 = OpULessThanEqual %bool %index %bound_minus_one
811 // %t2 = OpSelect %int_type %t1 %index %bound_minus_one
812 // %result = OpAccessChain %type %object ... %t2 ...
813 //
814 // ... unless %index is already a constant.
815
816 // Get the bound for the composite being indexed into; e.g. the number of
817 // columns of matrix or the size of an array.
818 uint32_t bound = fuzzerutil::GetBoundForCompositeIndex(
819 *should_be_composite_type, ir_context);
820
821 // Get the instruction associated with the index and figure out its integer
822 // type.
823 const uint32_t index_id = access_chain_inst->GetSingleWordInOperand(index);
824 auto index_inst = ir_context->get_def_use_mgr()->GetDef(index_id);
825 auto index_type_inst =
826 ir_context->get_def_use_mgr()->GetDef(index_inst->type_id());
827 assert(index_type_inst->opcode() == SpvOpTypeInt);
828 assert(index_type_inst->GetSingleWordInOperand(0) == 32);
829 opt::analysis::Integer* index_int_type =
830 ir_context->get_type_mgr()
831 ->GetType(index_type_inst->result_id())
832 ->AsInteger();
833
834 if (index_inst->opcode() != SpvOpConstant ||
835 index_inst->GetSingleWordInOperand(0) >= bound) {
836 // The index is either non-constant or an out-of-bounds constant, so we
837 // need to clamp it.
838 assert(should_be_composite_type->opcode() != SpvOpTypeStruct &&
839 "Access chain indices into structures are required to be "
840 "constants.");
841 opt::analysis::IntConstant bound_minus_one(index_int_type, {bound - 1});
842 if (!ir_context->get_constant_mgr()->FindConstant(&bound_minus_one)) {
843 // We do not have an integer constant whose value is |bound| -1.
844 return false;
845 }
846
847 opt::analysis::Bool bool_type;
848 uint32_t bool_type_id = ir_context->get_type_mgr()->GetId(&bool_type);
849 if (!bool_type_id) {
850 // Bool type is not declared; we cannot do a comparison.
851 return false;
852 }
853
854 uint32_t bound_minus_one_id =
855 ir_context->get_constant_mgr()
856 ->GetDefiningInstruction(&bound_minus_one)
857 ->result_id();
858
859 uint32_t compare_id =
860 access_chain_clamping_info->compare_and_select_ids(index - 1).first();
861 uint32_t select_id =
862 access_chain_clamping_info->compare_and_select_ids(index - 1)
863 .second();
864 std::vector<std::unique_ptr<opt::Instruction>> new_instructions;
865
866 // Compare the index with the bound via an instruction of the form:
867 // %t1 = OpULessThanEqual %bool %index %bound_minus_one
868 new_instructions.push_back(MakeUnique<opt::Instruction>(
869 ir_context, SpvOpULessThanEqual, bool_type_id, compare_id,
870 opt::Instruction::OperandList(
871 {{SPV_OPERAND_TYPE_ID, {index_inst->result_id()}},
872 {SPV_OPERAND_TYPE_ID, {bound_minus_one_id}}})));
873
874 // Select the index if in-bounds, otherwise one less than the bound:
875 // %t2 = OpSelect %int_type %t1 %index %bound_minus_one
876 new_instructions.push_back(MakeUnique<opt::Instruction>(
877 ir_context, SpvOpSelect, index_type_inst->result_id(), select_id,
878 opt::Instruction::OperandList(
879 {{SPV_OPERAND_TYPE_ID, {compare_id}},
880 {SPV_OPERAND_TYPE_ID, {index_inst->result_id()}},
881 {SPV_OPERAND_TYPE_ID, {bound_minus_one_id}}})));
882
883 // Add the new instructions before the access chain
884 access_chain_inst->InsertBefore(std::move(new_instructions));
885
886 // Replace %index with %t2.
887 access_chain_inst->SetInOperand(index, {select_id});
888 fuzzerutil::UpdateModuleIdBound(ir_context, compare_id);
889 fuzzerutil::UpdateModuleIdBound(ir_context, select_id);
890 }
891 should_be_composite_type =
892 FollowCompositeIndex(ir_context, *should_be_composite_type, index_id);
893 }
894 return true;
895 }
896
FollowCompositeIndex(opt::IRContext * ir_context,const opt::Instruction & composite_type_inst,uint32_t index_id)897 opt::Instruction* TransformationAddFunction::FollowCompositeIndex(
898 opt::IRContext* ir_context, const opt::Instruction& composite_type_inst,
899 uint32_t index_id) {
900 uint32_t sub_object_type_id;
901 switch (composite_type_inst.opcode()) {
902 case SpvOpTypeArray:
903 case SpvOpTypeRuntimeArray:
904 sub_object_type_id = composite_type_inst.GetSingleWordInOperand(0);
905 break;
906 case SpvOpTypeMatrix:
907 case SpvOpTypeVector:
908 sub_object_type_id = composite_type_inst.GetSingleWordInOperand(0);
909 break;
910 case SpvOpTypeStruct: {
911 auto index_inst = ir_context->get_def_use_mgr()->GetDef(index_id);
912 assert(index_inst->opcode() == SpvOpConstant);
913 assert(ir_context->get_def_use_mgr()
914 ->GetDef(index_inst->type_id())
915 ->opcode() == SpvOpTypeInt);
916 assert(ir_context->get_def_use_mgr()
917 ->GetDef(index_inst->type_id())
918 ->GetSingleWordInOperand(0) == 32);
919 uint32_t index_value = index_inst->GetSingleWordInOperand(0);
920 sub_object_type_id =
921 composite_type_inst.GetSingleWordInOperand(index_value);
922 break;
923 }
924 default:
925 assert(false && "Unknown composite type.");
926 sub_object_type_id = 0;
927 break;
928 }
929 assert(sub_object_type_id && "No sub-object found.");
930 return ir_context->get_def_use_mgr()->GetDef(sub_object_type_id);
931 }
932
GetFreshIds() const933 std::unordered_set<uint32_t> TransformationAddFunction::GetFreshIds() const {
934 std::unordered_set<uint32_t> result;
935 for (auto& instruction : message_.instruction()) {
936 result.insert(instruction.result_id());
937 }
938 if (message_.is_livesafe()) {
939 result.insert(message_.loop_limiter_variable_id());
940 for (auto& loop_limiter_info : message_.loop_limiter_info()) {
941 result.insert(loop_limiter_info.load_id());
942 result.insert(loop_limiter_info.increment_id());
943 result.insert(loop_limiter_info.compare_id());
944 result.insert(loop_limiter_info.logical_op_id());
945 }
946 for (auto& access_chain_clamping_info :
947 message_.access_chain_clamping_info()) {
948 for (auto& pair : access_chain_clamping_info.compare_and_select_ids()) {
949 result.insert(pair.first());
950 result.insert(pair.second());
951 }
952 }
953 }
954 return result;
955 }
956
957 } // namespace fuzz
958 } // namespace spvtools
959