• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 
15 #include "source/fuzz/transformation_add_function.h"
16 
17 #include "source/fuzz/fuzzer_util.h"
18 #include "source/fuzz/instruction_message.h"
19 
20 namespace spvtools {
21 namespace fuzz {
22 
TransformationAddFunction(protobufs::TransformationAddFunction message)23 TransformationAddFunction::TransformationAddFunction(
24     protobufs::TransformationAddFunction message)
25     : message_(std::move(message)) {}
26 
TransformationAddFunction(const std::vector<protobufs::Instruction> & instructions)27 TransformationAddFunction::TransformationAddFunction(
28     const std::vector<protobufs::Instruction>& instructions) {
29   for (auto& instruction : instructions) {
30     *message_.add_instruction() = instruction;
31   }
32   message_.set_is_livesafe(false);
33 }
34 
TransformationAddFunction(const std::vector<protobufs::Instruction> & instructions,uint32_t loop_limiter_variable_id,uint32_t loop_limit_constant_id,const std::vector<protobufs::LoopLimiterInfo> & loop_limiters,uint32_t kill_unreachable_return_value_id,const std::vector<protobufs::AccessChainClampingInfo> & access_chain_clampers)35 TransformationAddFunction::TransformationAddFunction(
36     const std::vector<protobufs::Instruction>& instructions,
37     uint32_t loop_limiter_variable_id, uint32_t loop_limit_constant_id,
38     const std::vector<protobufs::LoopLimiterInfo>& loop_limiters,
39     uint32_t kill_unreachable_return_value_id,
40     const std::vector<protobufs::AccessChainClampingInfo>&
41         access_chain_clampers) {
42   for (auto& instruction : instructions) {
43     *message_.add_instruction() = instruction;
44   }
45   message_.set_is_livesafe(true);
46   message_.set_loop_limiter_variable_id(loop_limiter_variable_id);
47   message_.set_loop_limit_constant_id(loop_limit_constant_id);
48   for (auto& loop_limiter : loop_limiters) {
49     *message_.add_loop_limiter_info() = loop_limiter;
50   }
51   message_.set_kill_unreachable_return_value_id(
52       kill_unreachable_return_value_id);
53   for (auto& access_clamper : access_chain_clampers) {
54     *message_.add_access_chain_clamping_info() = access_clamper;
55   }
56 }
57 
IsApplicable(opt::IRContext * ir_context,const TransformationContext & transformation_context) const58 bool TransformationAddFunction::IsApplicable(
59     opt::IRContext* ir_context,
60     const TransformationContext& transformation_context) const {
61   // This transformation may use a lot of ids, all of which need to be fresh
62   // and distinct.  This set tracks them.
63   std::set<uint32_t> ids_used_by_this_transformation;
64 
65   // Ensure that all result ids in the new function are fresh and distinct.
66   for (auto& instruction : message_.instruction()) {
67     if (instruction.result_id()) {
68       if (!CheckIdIsFreshAndNotUsedByThisTransformation(
69               instruction.result_id(), ir_context,
70               &ids_used_by_this_transformation)) {
71         return false;
72       }
73     }
74   }
75 
76   if (message_.is_livesafe()) {
77     // Ensure that all ids provided for making the function livesafe are fresh
78     // and distinct.
79     if (!CheckIdIsFreshAndNotUsedByThisTransformation(
80             message_.loop_limiter_variable_id(), ir_context,
81             &ids_used_by_this_transformation)) {
82       return false;
83     }
84     for (auto& loop_limiter_info : message_.loop_limiter_info()) {
85       if (!CheckIdIsFreshAndNotUsedByThisTransformation(
86               loop_limiter_info.load_id(), ir_context,
87               &ids_used_by_this_transformation)) {
88         return false;
89       }
90       if (!CheckIdIsFreshAndNotUsedByThisTransformation(
91               loop_limiter_info.increment_id(), ir_context,
92               &ids_used_by_this_transformation)) {
93         return false;
94       }
95       if (!CheckIdIsFreshAndNotUsedByThisTransformation(
96               loop_limiter_info.compare_id(), ir_context,
97               &ids_used_by_this_transformation)) {
98         return false;
99       }
100       if (!CheckIdIsFreshAndNotUsedByThisTransformation(
101               loop_limiter_info.logical_op_id(), ir_context,
102               &ids_used_by_this_transformation)) {
103         return false;
104       }
105     }
106     for (auto& access_chain_clamping_info :
107          message_.access_chain_clamping_info()) {
108       for (auto& pair : access_chain_clamping_info.compare_and_select_ids()) {
109         if (!CheckIdIsFreshAndNotUsedByThisTransformation(
110                 pair.first(), ir_context, &ids_used_by_this_transformation)) {
111           return false;
112         }
113         if (!CheckIdIsFreshAndNotUsedByThisTransformation(
114                 pair.second(), ir_context, &ids_used_by_this_transformation)) {
115           return false;
116         }
117       }
118     }
119   }
120 
121   // Because checking all the conditions for a function to be valid is a big
122   // job that the SPIR-V validator can already do, a "try it and see" approach
123   // is taken here.
124 
125   // We first clone the current module, so that we can try adding the new
126   // function without risking wrecking |ir_context|.
127   auto cloned_module = fuzzerutil::CloneIRContext(ir_context);
128 
129   // We try to add a function to the cloned module, which may fail if
130   // |message_.instruction| is not sufficiently well-formed.
131   if (!TryToAddFunction(cloned_module.get())) {
132     return false;
133   }
134 
135   // Check whether the cloned module is still valid after adding the function.
136   // If it is not, the transformation is not applicable.
137   if (!fuzzerutil::IsValid(cloned_module.get(),
138                            transformation_context.GetValidatorOptions(),
139                            fuzzerutil::kSilentMessageConsumer)) {
140     return false;
141   }
142 
143   if (message_.is_livesafe()) {
144     if (!TryToMakeFunctionLivesafe(cloned_module.get(),
145                                    transformation_context)) {
146       return false;
147     }
148     // After making the function livesafe, we check validity of the module
149     // again.  This is because the turning of OpKill, OpUnreachable and OpReturn
150     // instructions into branches changes control flow graph reachability, which
151     // has the potential to make the module invalid when it was otherwise valid.
152     // It is simpler to rely on the validator to guard against this than to
153     // consider all scenarios when making a function livesafe.
154     if (!fuzzerutil::IsValid(cloned_module.get(),
155                              transformation_context.GetValidatorOptions(),
156                              fuzzerutil::kSilentMessageConsumer)) {
157       return false;
158     }
159   }
160   return true;
161 }
162 
Apply(opt::IRContext * ir_context,TransformationContext * transformation_context) const163 void TransformationAddFunction::Apply(
164     opt::IRContext* ir_context,
165     TransformationContext* transformation_context) const {
166   // Add the function to the module.  As the transformation is applicable, this
167   // should succeed.
168   bool success = TryToAddFunction(ir_context);
169   assert(success && "The function should be successfully added.");
170   (void)(success);  // Keep release builds happy (otherwise they may complain
171                     // that |success| is not used).
172 
173   if (message_.is_livesafe()) {
174     // Make the function livesafe, which also should succeed.
175     success = TryToMakeFunctionLivesafe(ir_context, *transformation_context);
176     assert(success && "It should be possible to make the function livesafe.");
177     (void)(success);  // Keep release builds happy.
178   }
179   ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
180 
181   assert(spv::Op(message_.instruction(0).opcode()) == spv::Op::OpFunction &&
182          "The first instruction of an 'add function' transformation must be "
183          "OpFunction.");
184 
185   if (message_.is_livesafe()) {
186     // Inform the fact manager that the function is livesafe.
187     transformation_context->GetFactManager()->AddFactFunctionIsLivesafe(
188         message_.instruction(0).result_id());
189   } else {
190     // Inform the fact manager that all blocks in the function are dead.
191     for (auto& inst : message_.instruction()) {
192       if (spv::Op(inst.opcode()) == spv::Op::OpLabel) {
193         transformation_context->GetFactManager()->AddFactBlockIsDead(
194             inst.result_id());
195       }
196     }
197   }
198 
199   // Record the fact that all pointer parameters and variables declared in the
200   // function should be regarded as having irrelevant values.  This allows other
201   // passes to store arbitrarily to such variables, and to pass them freely as
202   // parameters to other functions knowing that it is OK if they get
203   // over-written.
204   for (auto& instruction : message_.instruction()) {
205     switch (spv::Op(instruction.opcode())) {
206       case spv::Op::OpFunctionParameter:
207         if (ir_context->get_def_use_mgr()
208                 ->GetDef(instruction.result_type_id())
209                 ->opcode() == spv::Op::OpTypePointer) {
210           transformation_context->GetFactManager()
211               ->AddFactValueOfPointeeIsIrrelevant(instruction.result_id());
212         }
213         break;
214       case spv::Op::OpVariable:
215         transformation_context->GetFactManager()
216             ->AddFactValueOfPointeeIsIrrelevant(instruction.result_id());
217         break;
218       default:
219         break;
220     }
221   }
222 }
223 
ToMessage() const224 protobufs::Transformation TransformationAddFunction::ToMessage() const {
225   protobufs::Transformation result;
226   *result.mutable_add_function() = message_;
227   return result;
228 }
229 
TryToAddFunction(opt::IRContext * ir_context) const230 bool TransformationAddFunction::TryToAddFunction(
231     opt::IRContext* ir_context) const {
232   // This function returns false if |message_.instruction| was not well-formed
233   // enough to actually create a function and add it to |ir_context|.
234 
235   // A function must have at least some instructions.
236   if (message_.instruction().empty()) {
237     return false;
238   }
239 
240   // A function must start with OpFunction.
241   auto function_begin = message_.instruction(0);
242   if (spv::Op(function_begin.opcode()) != spv::Op::OpFunction) {
243     return false;
244   }
245 
246   // Make a function, headed by the OpFunction instruction.
247   std::unique_ptr<opt::Function> new_function = MakeUnique<opt::Function>(
248       InstructionFromMessage(ir_context, function_begin));
249 
250   // Keeps track of which instruction protobuf message we are currently
251   // considering.
252   uint32_t instruction_index = 1;
253   const auto num_instructions =
254       static_cast<uint32_t>(message_.instruction().size());
255 
256   // Iterate through all function parameter instructions, adding parameters to
257   // the new function.
258   while (instruction_index < num_instructions &&
259          spv::Op(message_.instruction(instruction_index).opcode()) ==
260              spv::Op::OpFunctionParameter) {
261     new_function->AddParameter(InstructionFromMessage(
262         ir_context, message_.instruction(instruction_index)));
263     instruction_index++;
264   }
265 
266   // After the parameters, there needs to be a label.
267   if (instruction_index == num_instructions ||
268       spv::Op(message_.instruction(instruction_index).opcode()) !=
269           spv::Op::OpLabel) {
270     return false;
271   }
272 
273   // Iterate through the instructions block by block until the end of the
274   // function is reached.
275   while (instruction_index < num_instructions &&
276          spv::Op(message_.instruction(instruction_index).opcode()) !=
277              spv::Op::OpFunctionEnd) {
278     // Invariant: we should always be at a label instruction at this point.
279     assert(spv::Op(message_.instruction(instruction_index).opcode()) ==
280            spv::Op::OpLabel);
281 
282     // Make a basic block using the label instruction.
283     std::unique_ptr<opt::BasicBlock> block =
284         MakeUnique<opt::BasicBlock>(InstructionFromMessage(
285             ir_context, message_.instruction(instruction_index)));
286 
287     // Consider successive instructions until we hit another label or the end
288     // of the function, adding each such instruction to the block.
289     instruction_index++;
290     while (instruction_index < num_instructions &&
291            spv::Op(message_.instruction(instruction_index).opcode()) !=
292                spv::Op::OpFunctionEnd &&
293            spv::Op(message_.instruction(instruction_index).opcode()) !=
294                spv::Op::OpLabel) {
295       block->AddInstruction(InstructionFromMessage(
296           ir_context, message_.instruction(instruction_index)));
297       instruction_index++;
298     }
299     // Add the block to the new function.
300     new_function->AddBasicBlock(std::move(block));
301   }
302   // Having considered all the blocks, we should be at the last instruction and
303   // it needs to be OpFunctionEnd.
304   if (instruction_index != num_instructions - 1 ||
305       spv::Op(message_.instruction(instruction_index).opcode()) !=
306           spv::Op::OpFunctionEnd) {
307     return false;
308   }
309   // Set the function's final instruction, add the function to the module and
310   // report success.
311   new_function->SetFunctionEnd(InstructionFromMessage(
312       ir_context, message_.instruction(instruction_index)));
313   ir_context->AddFunction(std::move(new_function));
314 
315   ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
316 
317   return true;
318 }
319 
TryToMakeFunctionLivesafe(opt::IRContext * ir_context,const TransformationContext & transformation_context) const320 bool TransformationAddFunction::TryToMakeFunctionLivesafe(
321     opt::IRContext* ir_context,
322     const TransformationContext& transformation_context) const {
323   assert(message_.is_livesafe() && "Precondition: is_livesafe must hold.");
324 
325   // Get a pointer to the added function.
326   opt::Function* added_function = nullptr;
327   for (auto& function : *ir_context->module()) {
328     if (function.result_id() == message_.instruction(0).result_id()) {
329       added_function = &function;
330       break;
331     }
332   }
333   assert(added_function && "The added function should have been found.");
334 
335   if (!TryToAddLoopLimiters(ir_context, added_function)) {
336     // Adding loop limiters did not work; bail out.
337     return false;
338   }
339 
340   // Consider all the instructions in the function, and:
341   // - attempt to replace OpKill and OpUnreachable with return instructions
342   // - attempt to clamp access chains to be within bounds
343   // - check that OpFunctionCall instructions are only to livesafe functions
344   for (auto& block : *added_function) {
345     for (auto& inst : block) {
346       switch (inst.opcode()) {
347         case spv::Op::OpKill:
348         case spv::Op::OpUnreachable:
349           if (!TryToTurnKillOrUnreachableIntoReturn(ir_context, added_function,
350                                                     &inst)) {
351             return false;
352           }
353           break;
354         case spv::Op::OpAccessChain:
355         case spv::Op::OpInBoundsAccessChain:
356           if (!TryToClampAccessChainIndices(ir_context, &inst)) {
357             return false;
358           }
359           break;
360         case spv::Op::OpFunctionCall:
361           // A livesafe function my only call other livesafe functions.
362           if (!transformation_context.GetFactManager()->FunctionIsLivesafe(
363                   inst.GetSingleWordInOperand(0))) {
364             return false;
365           }
366         default:
367           break;
368       }
369     }
370   }
371   return true;
372 }
373 
GetBackEdgeBlockId(opt::IRContext * ir_context,uint32_t loop_header_block_id)374 uint32_t TransformationAddFunction::GetBackEdgeBlockId(
375     opt::IRContext* ir_context, uint32_t loop_header_block_id) {
376   const auto* loop_header_block =
377       ir_context->cfg()->block(loop_header_block_id);
378   assert(loop_header_block && "|loop_header_block_id| is invalid");
379 
380   for (auto pred : ir_context->cfg()->preds(loop_header_block_id)) {
381     if (ir_context->GetDominatorAnalysis(loop_header_block->GetParent())
382             ->Dominates(loop_header_block_id, pred)) {
383       return pred;
384     }
385   }
386 
387   return 0;
388 }
389 
TryToAddLoopLimiters(opt::IRContext * ir_context,opt::Function * added_function) const390 bool TransformationAddFunction::TryToAddLoopLimiters(
391     opt::IRContext* ir_context, opt::Function* added_function) const {
392   // Collect up all the loop headers so that we can subsequently add loop
393   // limiting logic.
394   std::vector<opt::BasicBlock*> loop_headers;
395   for (auto& block : *added_function) {
396     if (block.IsLoopHeader()) {
397       loop_headers.push_back(&block);
398     }
399   }
400 
401   if (loop_headers.empty()) {
402     // There are no loops, so no need to add any loop limiters.
403     return true;
404   }
405 
406   // Check that the module contains appropriate ingredients for declaring and
407   // manipulating a loop limiter.
408 
409   auto loop_limit_constant_id_instr =
410       ir_context->get_def_use_mgr()->GetDef(message_.loop_limit_constant_id());
411   if (!loop_limit_constant_id_instr ||
412       loop_limit_constant_id_instr->opcode() != spv::Op::OpConstant) {
413     // The loop limit constant id instruction must exist and have an
414     // appropriate opcode.
415     return false;
416   }
417 
418   auto loop_limit_type = ir_context->get_def_use_mgr()->GetDef(
419       loop_limit_constant_id_instr->type_id());
420   if (loop_limit_type->opcode() != spv::Op::OpTypeInt ||
421       loop_limit_type->GetSingleWordInOperand(0) != 32) {
422     // The type of the loop limit constant must be 32-bit integer.  It
423     // doesn't actually matter whether the integer is signed or not.
424     return false;
425   }
426 
427   // Find the id of the "unsigned int" type.
428   opt::analysis::Integer unsigned_int_type(32, false);
429   uint32_t unsigned_int_type_id =
430       ir_context->get_type_mgr()->GetId(&unsigned_int_type);
431   if (!unsigned_int_type_id) {
432     // Unsigned int is not available; we need this type in order to add loop
433     // limiters.
434     return false;
435   }
436   auto registered_unsigned_int_type =
437       ir_context->get_type_mgr()->GetRegisteredType(&unsigned_int_type);
438 
439   // Look for 0 of type unsigned int.
440   opt::analysis::IntConstant zero(registered_unsigned_int_type->AsInteger(),
441                                   {0});
442   auto registered_zero = ir_context->get_constant_mgr()->FindConstant(&zero);
443   if (!registered_zero) {
444     // We need 0 in order to be able to initialize loop limiters.
445     return false;
446   }
447   uint32_t zero_id = ir_context->get_constant_mgr()
448                          ->GetDefiningInstruction(registered_zero)
449                          ->result_id();
450 
451   // Look for 1 of type unsigned int.
452   opt::analysis::IntConstant one(registered_unsigned_int_type->AsInteger(),
453                                  {1});
454   auto registered_one = ir_context->get_constant_mgr()->FindConstant(&one);
455   if (!registered_one) {
456     // We need 1 in order to be able to increment loop limiters.
457     return false;
458   }
459   uint32_t one_id = ir_context->get_constant_mgr()
460                         ->GetDefiningInstruction(registered_one)
461                         ->result_id();
462 
463   // Look for pointer-to-unsigned int type.
464   opt::analysis::Pointer pointer_to_unsigned_int_type(
465       registered_unsigned_int_type, spv::StorageClass::Function);
466   uint32_t pointer_to_unsigned_int_type_id =
467       ir_context->get_type_mgr()->GetId(&pointer_to_unsigned_int_type);
468   if (!pointer_to_unsigned_int_type_id) {
469     // We need pointer-to-unsigned int in order to declare the loop limiter
470     // variable.
471     return false;
472   }
473 
474   // Look for bool type.
475   opt::analysis::Bool bool_type;
476   uint32_t bool_type_id = ir_context->get_type_mgr()->GetId(&bool_type);
477   if (!bool_type_id) {
478     // We need bool in order to compare the loop limiter's value with the loop
479     // limit constant.
480     return false;
481   }
482 
483   // Declare the loop limiter variable at the start of the function's entry
484   // block, via an instruction of the form:
485   //   %loop_limiter_var = spv::Op::OpVariable %ptr_to_uint Function %zero
486   added_function->begin()->begin()->InsertBefore(MakeUnique<opt::Instruction>(
487       ir_context, spv::Op::OpVariable, pointer_to_unsigned_int_type_id,
488       message_.loop_limiter_variable_id(),
489       opt::Instruction::OperandList({{SPV_OPERAND_TYPE_STORAGE_CLASS,
490                                       {uint32_t(spv::StorageClass::Function)}},
491                                      {SPV_OPERAND_TYPE_ID, {zero_id}}})));
492   // Update the module's id bound since we have added the loop limiter
493   // variable id.
494   fuzzerutil::UpdateModuleIdBound(ir_context,
495                                   message_.loop_limiter_variable_id());
496 
497   // Consider each loop in turn.
498   for (auto loop_header : loop_headers) {
499     // Look for the loop's back-edge block.  This is a predecessor of the loop
500     // header that is dominated by the loop header.
501     const auto back_edge_block_id =
502         GetBackEdgeBlockId(ir_context, loop_header->id());
503     if (!back_edge_block_id) {
504       // The loop's back-edge block must be unreachable.  This means that the
505       // loop cannot iterate, so there is no need to make it lifesafe; we can
506       // move on from this loop.
507       continue;
508     }
509 
510     // If the loop's merge block is unreachable, then there are no constraints
511     // on where the merge block appears in relation to the blocks of the loop.
512     // This means we need to be careful when adding a branch from the back-edge
513     // block to the merge block: the branch might make the loop merge reachable,
514     // and it might then be dominated by the loop header and possibly by other
515     // blocks in the loop. Since a block needs to appear before those blocks it
516     // strictly dominates, this could make the module invalid. To avoid this
517     // problem we bail out in the case where the loop header does not dominate
518     // the loop merge.
519     if (!ir_context->GetDominatorAnalysis(added_function)
520              ->Dominates(loop_header->id(), loop_header->MergeBlockId())) {
521       return false;
522     }
523 
524     // Go through the sequence of loop limiter infos and find the one
525     // corresponding to this loop.
526     bool found = false;
527     protobufs::LoopLimiterInfo loop_limiter_info;
528     for (auto& info : message_.loop_limiter_info()) {
529       if (info.loop_header_id() == loop_header->id()) {
530         loop_limiter_info = info;
531         found = true;
532         break;
533       }
534     }
535     if (!found) {
536       // We don't have loop limiter info for this loop header.
537       return false;
538     }
539 
540     // The back-edge block either has the form:
541     //
542     // (1)
543     //
544     // %l = OpLabel
545     //      ... instructions ...
546     //      OpBranch %loop_header
547     //
548     // (2)
549     //
550     // %l = OpLabel
551     //      ... instructions ...
552     //      OpBranchConditional %c %loop_header %loop_merge
553     //
554     // (3)
555     //
556     // %l = OpLabel
557     //      ... instructions ...
558     //      OpBranchConditional %c %loop_merge %loop_header
559     //
560     // We turn these into the following:
561     //
562     // (1)
563     //
564     //  %l = OpLabel
565     //       ... instructions ...
566     // %t1 = OpLoad %uint32 %loop_limiter
567     // %t2 = OpIAdd %uint32 %t1 %one
568     //       OpStore %loop_limiter %t2
569     // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
570     //       OpBranchConditional %t3 %loop_merge %loop_header
571     //
572     // (2)
573     //
574     //  %l = OpLabel
575     //       ... instructions ...
576     // %t1 = OpLoad %uint32 %loop_limiter
577     // %t2 = OpIAdd %uint32 %t1 %one
578     //       OpStore %loop_limiter %t2
579     // %t3 = OpULessThan %bool %t1 %loop_limit
580     // %t4 = OpLogicalAnd %bool %c %t3
581     //       OpBranchConditional %t4 %loop_header %loop_merge
582     //
583     // (3)
584     //
585     //  %l = OpLabel
586     //       ... instructions ...
587     // %t1 = OpLoad %uint32 %loop_limiter
588     // %t2 = OpIAdd %uint32 %t1 %one
589     //       OpStore %loop_limiter %t2
590     // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
591     // %t4 = OpLogicalOr %bool %c %t3
592     //       OpBranchConditional %t4 %loop_merge %loop_header
593 
594     auto back_edge_block = ir_context->cfg()->block(back_edge_block_id);
595     auto back_edge_block_terminator = back_edge_block->terminator();
596     bool compare_using_greater_than_equal;
597     if (back_edge_block_terminator->opcode() == spv::Op::OpBranch) {
598       compare_using_greater_than_equal = true;
599     } else {
600       assert(back_edge_block_terminator->opcode() ==
601              spv::Op::OpBranchConditional);
602       assert(((back_edge_block_terminator->GetSingleWordInOperand(1) ==
603                    loop_header->id() &&
604                back_edge_block_terminator->GetSingleWordInOperand(2) ==
605                    loop_header->MergeBlockId()) ||
606               (back_edge_block_terminator->GetSingleWordInOperand(2) ==
607                    loop_header->id() &&
608                back_edge_block_terminator->GetSingleWordInOperand(1) ==
609                    loop_header->MergeBlockId())) &&
610              "A back edge edge block must branch to"
611              " either the loop header or merge");
612       compare_using_greater_than_equal =
613           back_edge_block_terminator->GetSingleWordInOperand(1) ==
614           loop_header->MergeBlockId();
615     }
616 
617     std::vector<std::unique_ptr<opt::Instruction>> new_instructions;
618 
619     // Add a load from the loop limiter variable, of the form:
620     //   %t1 = OpLoad %uint32 %loop_limiter
621     new_instructions.push_back(MakeUnique<opt::Instruction>(
622         ir_context, spv::Op::OpLoad, unsigned_int_type_id,
623         loop_limiter_info.load_id(),
624         opt::Instruction::OperandList(
625             {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}}})));
626 
627     // Increment the loaded value:
628     //   %t2 = OpIAdd %uint32 %t1 %one
629     new_instructions.push_back(MakeUnique<opt::Instruction>(
630         ir_context, spv::Op::OpIAdd, unsigned_int_type_id,
631         loop_limiter_info.increment_id(),
632         opt::Instruction::OperandList(
633             {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.load_id()}},
634              {SPV_OPERAND_TYPE_ID, {one_id}}})));
635 
636     // Store the incremented value back to the loop limiter variable:
637     //   OpStore %loop_limiter %t2
638     new_instructions.push_back(MakeUnique<opt::Instruction>(
639         ir_context, spv::Op::OpStore, 0, 0,
640         opt::Instruction::OperandList(
641             {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}},
642              {SPV_OPERAND_TYPE_ID, {loop_limiter_info.increment_id()}}})));
643 
644     // Compare the loaded value with the loop limit; either:
645     //   %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
646     // or
647     //   %t3 = OpULessThan %bool %t1 %loop_limit
648     new_instructions.push_back(MakeUnique<opt::Instruction>(
649         ir_context,
650         compare_using_greater_than_equal ? spv::Op::OpUGreaterThanEqual
651                                          : spv::Op::OpULessThan,
652         bool_type_id, loop_limiter_info.compare_id(),
653         opt::Instruction::OperandList(
654             {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.load_id()}},
655              {SPV_OPERAND_TYPE_ID, {message_.loop_limit_constant_id()}}})));
656 
657     if (back_edge_block_terminator->opcode() == spv::Op::OpBranchConditional) {
658       new_instructions.push_back(MakeUnique<opt::Instruction>(
659           ir_context,
660           compare_using_greater_than_equal ? spv::Op::OpLogicalOr
661                                            : spv::Op::OpLogicalAnd,
662           bool_type_id, loop_limiter_info.logical_op_id(),
663           opt::Instruction::OperandList(
664               {{SPV_OPERAND_TYPE_ID,
665                 {back_edge_block_terminator->GetSingleWordInOperand(0)}},
666                {SPV_OPERAND_TYPE_ID, {loop_limiter_info.compare_id()}}})));
667     }
668 
669     // Add the new instructions at the end of the back edge block, before the
670     // terminator and any loop merge instruction (as the back edge block can
671     // be the loop header).
672     if (back_edge_block->GetLoopMergeInst()) {
673       back_edge_block->GetLoopMergeInst()->InsertBefore(
674           std::move(new_instructions));
675     } else {
676       back_edge_block_terminator->InsertBefore(std::move(new_instructions));
677     }
678 
679     if (back_edge_block_terminator->opcode() == spv::Op::OpBranchConditional) {
680       back_edge_block_terminator->SetInOperand(
681           0, {loop_limiter_info.logical_op_id()});
682     } else {
683       assert(back_edge_block_terminator->opcode() == spv::Op::OpBranch &&
684              "Back-edge terminator must be OpBranch or OpBranchConditional");
685 
686       // Check that, if the merge block starts with OpPhi instructions, suitable
687       // ids have been provided to give these instructions a value corresponding
688       // to the new incoming edge from the back edge block.
689       auto merge_block = ir_context->cfg()->block(loop_header->MergeBlockId());
690       if (!fuzzerutil::PhiIdsOkForNewEdge(ir_context, back_edge_block,
691                                           merge_block,
692                                           loop_limiter_info.phi_id())) {
693         return false;
694       }
695 
696       // Augment OpPhi instructions at the loop merge with the given ids.
697       uint32_t phi_index = 0;
698       for (auto& inst : *merge_block) {
699         if (inst.opcode() != spv::Op::OpPhi) {
700           break;
701         }
702         assert(phi_index <
703                    static_cast<uint32_t>(loop_limiter_info.phi_id().size()) &&
704                "There should be at least one phi id per OpPhi instruction.");
705         inst.AddOperand(
706             {SPV_OPERAND_TYPE_ID, {loop_limiter_info.phi_id(phi_index)}});
707         inst.AddOperand({SPV_OPERAND_TYPE_ID, {back_edge_block_id}});
708         phi_index++;
709       }
710 
711       // Add the new edge, by changing OpBranch to OpBranchConditional.
712       back_edge_block_terminator->SetOpcode(spv::Op::OpBranchConditional);
713       back_edge_block_terminator->SetInOperands(opt::Instruction::OperandList(
714           {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.compare_id()}},
715            {SPV_OPERAND_TYPE_ID, {loop_header->MergeBlockId()}},
716            {SPV_OPERAND_TYPE_ID, {loop_header->id()}}}));
717     }
718 
719     // Update the module's id bound with respect to the various ids that
720     // have been used for loop limiter manipulation.
721     fuzzerutil::UpdateModuleIdBound(ir_context, loop_limiter_info.load_id());
722     fuzzerutil::UpdateModuleIdBound(ir_context,
723                                     loop_limiter_info.increment_id());
724     fuzzerutil::UpdateModuleIdBound(ir_context, loop_limiter_info.compare_id());
725     fuzzerutil::UpdateModuleIdBound(ir_context,
726                                     loop_limiter_info.logical_op_id());
727   }
728   return true;
729 }
730 
TryToTurnKillOrUnreachableIntoReturn(opt::IRContext * ir_context,opt::Function * added_function,opt::Instruction * kill_or_unreachable_inst) const731 bool TransformationAddFunction::TryToTurnKillOrUnreachableIntoReturn(
732     opt::IRContext* ir_context, opt::Function* added_function,
733     opt::Instruction* kill_or_unreachable_inst) const {
734   assert((kill_or_unreachable_inst->opcode() == spv::Op::OpKill ||
735           kill_or_unreachable_inst->opcode() == spv::Op::OpUnreachable) &&
736          "Precondition: instruction must be OpKill or OpUnreachable.");
737 
738   // Get the function's return type.
739   auto function_return_type_inst =
740       ir_context->get_def_use_mgr()->GetDef(added_function->type_id());
741 
742   if (function_return_type_inst->opcode() == spv::Op::OpTypeVoid) {
743     // The function has void return type, so change this instruction to
744     // OpReturn.
745     kill_or_unreachable_inst->SetOpcode(spv::Op::OpReturn);
746   } else {
747     // The function has non-void return type, so change this instruction
748     // to OpReturnValue, using the value id provided with the
749     // transformation.
750 
751     // We first check that the id, %id, provided with the transformation
752     // specifically to turn OpKill and OpUnreachable instructions into
753     // OpReturnValue %id has the same type as the function's return type.
754     if (ir_context->get_def_use_mgr()
755             ->GetDef(message_.kill_unreachable_return_value_id())
756             ->type_id() != function_return_type_inst->result_id()) {
757       return false;
758     }
759     kill_or_unreachable_inst->SetOpcode(spv::Op::OpReturnValue);
760     kill_or_unreachable_inst->SetInOperands(
761         {{SPV_OPERAND_TYPE_ID, {message_.kill_unreachable_return_value_id()}}});
762   }
763   return true;
764 }
765 
TryToClampAccessChainIndices(opt::IRContext * ir_context,opt::Instruction * access_chain_inst) const766 bool TransformationAddFunction::TryToClampAccessChainIndices(
767     opt::IRContext* ir_context, opt::Instruction* access_chain_inst) const {
768   assert((access_chain_inst->opcode() == spv::Op::OpAccessChain ||
769           access_chain_inst->opcode() == spv::Op::OpInBoundsAccessChain) &&
770          "Precondition: instruction must be OpAccessChain or "
771          "OpInBoundsAccessChain.");
772 
773   // Find the AccessChainClampingInfo associated with this access chain.
774   const protobufs::AccessChainClampingInfo* access_chain_clamping_info =
775       nullptr;
776   for (auto& clamping_info : message_.access_chain_clamping_info()) {
777     if (clamping_info.access_chain_id() == access_chain_inst->result_id()) {
778       access_chain_clamping_info = &clamping_info;
779       break;
780     }
781   }
782   if (!access_chain_clamping_info) {
783     // No access chain clamping information was found; the function cannot be
784     // made livesafe.
785     return false;
786   }
787 
788   // Check that there is a (compare_id, select_id) pair for every
789   // index associated with the instruction.
790   if (static_cast<uint32_t>(
791           access_chain_clamping_info->compare_and_select_ids().size()) !=
792       access_chain_inst->NumInOperands() - 1) {
793     return false;
794   }
795 
796   // Walk the access chain, clamping each index to be within bounds if it is
797   // not a constant.
798   auto base_object = ir_context->get_def_use_mgr()->GetDef(
799       access_chain_inst->GetSingleWordInOperand(0));
800   assert(base_object && "The base object must exist.");
801   auto pointer_type =
802       ir_context->get_def_use_mgr()->GetDef(base_object->type_id());
803   assert(pointer_type && pointer_type->opcode() == spv::Op::OpTypePointer &&
804          "The base object must have pointer type.");
805   auto should_be_composite_type = ir_context->get_def_use_mgr()->GetDef(
806       pointer_type->GetSingleWordInOperand(1));
807 
808   // Consider each index input operand in turn (operand 0 is the base object).
809   for (uint32_t index = 1; index < access_chain_inst->NumInOperands();
810        index++) {
811     // We are going to turn:
812     //
813     // %result = OpAccessChain %type %object ... %index ...
814     //
815     // into:
816     //
817     // %t1 = OpULessThanEqual %bool %index %bound_minus_one
818     // %t2 = OpSelect %int_type %t1 %index %bound_minus_one
819     // %result = OpAccessChain %type %object ... %t2 ...
820     //
821     // ... unless %index is already a constant.
822 
823     // Get the bound for the composite being indexed into; e.g. the number of
824     // columns of matrix or the size of an array.
825     uint32_t bound = fuzzerutil::GetBoundForCompositeIndex(
826         *should_be_composite_type, ir_context);
827 
828     // Get the instruction associated with the index and figure out its integer
829     // type.
830     const uint32_t index_id = access_chain_inst->GetSingleWordInOperand(index);
831     auto index_inst = ir_context->get_def_use_mgr()->GetDef(index_id);
832     auto index_type_inst =
833         ir_context->get_def_use_mgr()->GetDef(index_inst->type_id());
834     assert(index_type_inst->opcode() == spv::Op::OpTypeInt);
835     assert(index_type_inst->GetSingleWordInOperand(0) == 32);
836     opt::analysis::Integer* index_int_type =
837         ir_context->get_type_mgr()
838             ->GetType(index_type_inst->result_id())
839             ->AsInteger();
840 
841     if (index_inst->opcode() != spv::Op::OpConstant ||
842         index_inst->GetSingleWordInOperand(0) >= bound) {
843       // The index is either non-constant or an out-of-bounds constant, so we
844       // need to clamp it.
845       assert(should_be_composite_type->opcode() != spv::Op::OpTypeStruct &&
846              "Access chain indices into structures are required to be "
847              "constants.");
848       opt::analysis::IntConstant bound_minus_one(index_int_type, {bound - 1});
849       if (!ir_context->get_constant_mgr()->FindConstant(&bound_minus_one)) {
850         // We do not have an integer constant whose value is |bound| -1.
851         return false;
852       }
853 
854       opt::analysis::Bool bool_type;
855       uint32_t bool_type_id = ir_context->get_type_mgr()->GetId(&bool_type);
856       if (!bool_type_id) {
857         // Bool type is not declared; we cannot do a comparison.
858         return false;
859       }
860 
861       uint32_t bound_minus_one_id =
862           ir_context->get_constant_mgr()
863               ->GetDefiningInstruction(&bound_minus_one)
864               ->result_id();
865 
866       uint32_t compare_id =
867           access_chain_clamping_info->compare_and_select_ids(index - 1).first();
868       uint32_t select_id =
869           access_chain_clamping_info->compare_and_select_ids(index - 1)
870               .second();
871       std::vector<std::unique_ptr<opt::Instruction>> new_instructions;
872 
873       // Compare the index with the bound via an instruction of the form:
874       //   %t1 = OpULessThanEqual %bool %index %bound_minus_one
875       new_instructions.push_back(MakeUnique<opt::Instruction>(
876           ir_context, spv::Op::OpULessThanEqual, bool_type_id, compare_id,
877           opt::Instruction::OperandList(
878               {{SPV_OPERAND_TYPE_ID, {index_inst->result_id()}},
879                {SPV_OPERAND_TYPE_ID, {bound_minus_one_id}}})));
880 
881       // Select the index if in-bounds, otherwise one less than the bound:
882       //   %t2 = OpSelect %int_type %t1 %index %bound_minus_one
883       new_instructions.push_back(MakeUnique<opt::Instruction>(
884           ir_context, spv::Op::OpSelect, index_type_inst->result_id(),
885           select_id,
886           opt::Instruction::OperandList(
887               {{SPV_OPERAND_TYPE_ID, {compare_id}},
888                {SPV_OPERAND_TYPE_ID, {index_inst->result_id()}},
889                {SPV_OPERAND_TYPE_ID, {bound_minus_one_id}}})));
890 
891       // Add the new instructions before the access chain
892       access_chain_inst->InsertBefore(std::move(new_instructions));
893 
894       // Replace %index with %t2.
895       access_chain_inst->SetInOperand(index, {select_id});
896       fuzzerutil::UpdateModuleIdBound(ir_context, compare_id);
897       fuzzerutil::UpdateModuleIdBound(ir_context, select_id);
898     }
899     should_be_composite_type =
900         FollowCompositeIndex(ir_context, *should_be_composite_type, index_id);
901   }
902   return true;
903 }
904 
FollowCompositeIndex(opt::IRContext * ir_context,const opt::Instruction & composite_type_inst,uint32_t index_id)905 opt::Instruction* TransformationAddFunction::FollowCompositeIndex(
906     opt::IRContext* ir_context, const opt::Instruction& composite_type_inst,
907     uint32_t index_id) {
908   uint32_t sub_object_type_id;
909   switch (composite_type_inst.opcode()) {
910     case spv::Op::OpTypeArray:
911     case spv::Op::OpTypeRuntimeArray:
912       sub_object_type_id = composite_type_inst.GetSingleWordInOperand(0);
913       break;
914     case spv::Op::OpTypeMatrix:
915     case spv::Op::OpTypeVector:
916       sub_object_type_id = composite_type_inst.GetSingleWordInOperand(0);
917       break;
918     case spv::Op::OpTypeStruct: {
919       auto index_inst = ir_context->get_def_use_mgr()->GetDef(index_id);
920       assert(index_inst->opcode() == spv::Op::OpConstant);
921       assert(ir_context->get_def_use_mgr()
922                  ->GetDef(index_inst->type_id())
923                  ->opcode() == spv::Op::OpTypeInt);
924       assert(ir_context->get_def_use_mgr()
925                  ->GetDef(index_inst->type_id())
926                  ->GetSingleWordInOperand(0) == 32);
927       uint32_t index_value = index_inst->GetSingleWordInOperand(0);
928       sub_object_type_id =
929           composite_type_inst.GetSingleWordInOperand(index_value);
930       break;
931     }
932     default:
933       assert(false && "Unknown composite type.");
934       sub_object_type_id = 0;
935       break;
936   }
937   assert(sub_object_type_id && "No sub-object found.");
938   return ir_context->get_def_use_mgr()->GetDef(sub_object_type_id);
939 }
940 
GetFreshIds() const941 std::unordered_set<uint32_t> TransformationAddFunction::GetFreshIds() const {
942   std::unordered_set<uint32_t> result;
943   for (auto& instruction : message_.instruction()) {
944     result.insert(instruction.result_id());
945   }
946   if (message_.is_livesafe()) {
947     result.insert(message_.loop_limiter_variable_id());
948     for (auto& loop_limiter_info : message_.loop_limiter_info()) {
949       result.insert(loop_limiter_info.load_id());
950       result.insert(loop_limiter_info.increment_id());
951       result.insert(loop_limiter_info.compare_id());
952       result.insert(loop_limiter_info.logical_op_id());
953     }
954     for (auto& access_chain_clamping_info :
955          message_.access_chain_clamping_info()) {
956       for (auto& pair : access_chain_clamping_info.compare_and_select_ids()) {
957         result.insert(pair.first());
958         result.insert(pair.second());
959       }
960     }
961   }
962   return result;
963 }
964 
965 }  // namespace fuzz
966 }  // namespace spvtools
967