• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/fuzz/transformation_add_function.h"
16 
17 #include "source/fuzz/fuzzer_util.h"
18 #include "source/fuzz/instruction_message.h"
19 
20 namespace spvtools {
21 namespace fuzz {
22 
TransformationAddFunction(protobufs::TransformationAddFunction message)23 TransformationAddFunction::TransformationAddFunction(
24     protobufs::TransformationAddFunction message)
25     : message_(std::move(message)) {}
26 
TransformationAddFunction(const std::vector<protobufs::Instruction> & instructions)27 TransformationAddFunction::TransformationAddFunction(
28     const std::vector<protobufs::Instruction>& instructions) {
29   for (auto& instruction : instructions) {
30     *message_.add_instruction() = instruction;
31   }
32   message_.set_is_livesafe(false);
33 }
34 
TransformationAddFunction(const std::vector<protobufs::Instruction> & instructions,uint32_t loop_limiter_variable_id,uint32_t loop_limit_constant_id,const std::vector<protobufs::LoopLimiterInfo> & loop_limiters,uint32_t kill_unreachable_return_value_id,const std::vector<protobufs::AccessChainClampingInfo> & access_chain_clampers)35 TransformationAddFunction::TransformationAddFunction(
36     const std::vector<protobufs::Instruction>& instructions,
37     uint32_t loop_limiter_variable_id, uint32_t loop_limit_constant_id,
38     const std::vector<protobufs::LoopLimiterInfo>& loop_limiters,
39     uint32_t kill_unreachable_return_value_id,
40     const std::vector<protobufs::AccessChainClampingInfo>&
41         access_chain_clampers) {
42   for (auto& instruction : instructions) {
43     *message_.add_instruction() = instruction;
44   }
45   message_.set_is_livesafe(true);
46   message_.set_loop_limiter_variable_id(loop_limiter_variable_id);
47   message_.set_loop_limit_constant_id(loop_limit_constant_id);
48   for (auto& loop_limiter : loop_limiters) {
49     *message_.add_loop_limiter_info() = loop_limiter;
50   }
51   message_.set_kill_unreachable_return_value_id(
52       kill_unreachable_return_value_id);
53   for (auto& access_clamper : access_chain_clampers) {
54     *message_.add_access_chain_clamping_info() = access_clamper;
55   }
56 }
57 
IsApplicable(opt::IRContext * ir_context,const TransformationContext & transformation_context) const58 bool TransformationAddFunction::IsApplicable(
59     opt::IRContext* ir_context,
60     const TransformationContext& transformation_context) const {
61   // This transformation may use a lot of ids, all of which need to be fresh
62   // and distinct.  This set tracks them.
63   std::set<uint32_t> ids_used_by_this_transformation;
64 
65   // Ensure that all result ids in the new function are fresh and distinct.
66   for (auto& instruction : message_.instruction()) {
67     if (instruction.result_id()) {
68       if (!CheckIdIsFreshAndNotUsedByThisTransformation(
69               instruction.result_id(), ir_context,
70               &ids_used_by_this_transformation)) {
71         return false;
72       }
73     }
74   }
75 
76   if (message_.is_livesafe()) {
77     // Ensure that all ids provided for making the function livesafe are fresh
78     // and distinct.
79     if (!CheckIdIsFreshAndNotUsedByThisTransformation(
80             message_.loop_limiter_variable_id(), ir_context,
81             &ids_used_by_this_transformation)) {
82       return false;
83     }
84     for (auto& loop_limiter_info : message_.loop_limiter_info()) {
85       if (!CheckIdIsFreshAndNotUsedByThisTransformation(
86               loop_limiter_info.load_id(), ir_context,
87               &ids_used_by_this_transformation)) {
88         return false;
89       }
90       if (!CheckIdIsFreshAndNotUsedByThisTransformation(
91               loop_limiter_info.increment_id(), ir_context,
92               &ids_used_by_this_transformation)) {
93         return false;
94       }
95       if (!CheckIdIsFreshAndNotUsedByThisTransformation(
96               loop_limiter_info.compare_id(), ir_context,
97               &ids_used_by_this_transformation)) {
98         return false;
99       }
100       if (!CheckIdIsFreshAndNotUsedByThisTransformation(
101               loop_limiter_info.logical_op_id(), ir_context,
102               &ids_used_by_this_transformation)) {
103         return false;
104       }
105     }
106     for (auto& access_chain_clamping_info :
107          message_.access_chain_clamping_info()) {
108       for (auto& pair : access_chain_clamping_info.compare_and_select_ids()) {
109         if (!CheckIdIsFreshAndNotUsedByThisTransformation(
110                 pair.first(), ir_context, &ids_used_by_this_transformation)) {
111           return false;
112         }
113         if (!CheckIdIsFreshAndNotUsedByThisTransformation(
114                 pair.second(), ir_context, &ids_used_by_this_transformation)) {
115           return false;
116         }
117       }
118     }
119   }
120 
121   // Because checking all the conditions for a function to be valid is a big
122   // job that the SPIR-V validator can already do, a "try it and see" approach
123   // is taken here.
124 
125   // We first clone the current module, so that we can try adding the new
126   // function without risking wrecking |ir_context|.
127   auto cloned_module = fuzzerutil::CloneIRContext(ir_context);
128 
129   // We try to add a function to the cloned module, which may fail if
130   // |message_.instruction| is not sufficiently well-formed.
131   if (!TryToAddFunction(cloned_module.get())) {
132     return false;
133   }
134 
135   // Check whether the cloned module is still valid after adding the function.
136   // If it is not, the transformation is not applicable.
137   if (!fuzzerutil::IsValid(cloned_module.get(),
138                            transformation_context.GetValidatorOptions(),
139                            fuzzerutil::kSilentMessageConsumer)) {
140     return false;
141   }
142 
143   if (message_.is_livesafe()) {
144     if (!TryToMakeFunctionLivesafe(cloned_module.get(),
145                                    transformation_context)) {
146       return false;
147     }
148     // After making the function livesafe, we check validity of the module
149     // again.  This is because the turning of OpKill, OpUnreachable and OpReturn
150     // instructions into branches changes control flow graph reachability, which
151     // has the potential to make the module invalid when it was otherwise valid.
152     // It is simpler to rely on the validator to guard against this than to
153     // consider all scenarios when making a function livesafe.
154     if (!fuzzerutil::IsValid(cloned_module.get(),
155                              transformation_context.GetValidatorOptions(),
156                              fuzzerutil::kSilentMessageConsumer)) {
157       return false;
158     }
159   }
160   return true;
161 }
162 
Apply(opt::IRContext * ir_context,TransformationContext * transformation_context) const163 void TransformationAddFunction::Apply(
164     opt::IRContext* ir_context,
165     TransformationContext* transformation_context) const {
166   // Add the function to the module.  As the transformation is applicable, this
167   // should succeed.
168   bool success = TryToAddFunction(ir_context);
169   assert(success && "The function should be successfully added.");
170   (void)(success);  // Keep release builds happy (otherwise they may complain
171                     // that |success| is not used).
172 
173   if (message_.is_livesafe()) {
174     // Make the function livesafe, which also should succeed.
175     success = TryToMakeFunctionLivesafe(ir_context, *transformation_context);
176     assert(success && "It should be possible to make the function livesafe.");
177     (void)(success);  // Keep release builds happy.
178   }
179   ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
180 
181   assert(message_.instruction(0).opcode() == SpvOpFunction &&
182          "The first instruction of an 'add function' transformation must be "
183          "OpFunction.");
184 
185   if (message_.is_livesafe()) {
186     // Inform the fact manager that the function is livesafe.
187     transformation_context->GetFactManager()->AddFactFunctionIsLivesafe(
188         message_.instruction(0).result_id());
189   } else {
190     // Inform the fact manager that all blocks in the function are dead.
191     for (auto& inst : message_.instruction()) {
192       if (inst.opcode() == SpvOpLabel) {
193         transformation_context->GetFactManager()->AddFactBlockIsDead(
194             inst.result_id());
195       }
196     }
197   }
198 
199   // Record the fact that all pointer parameters and variables declared in the
200   // function should be regarded as having irrelevant values.  This allows other
201   // passes to store arbitrarily to such variables, and to pass them freely as
202   // parameters to other functions knowing that it is OK if they get
203   // over-written.
204   for (auto& instruction : message_.instruction()) {
205     switch (instruction.opcode()) {
206       case SpvOpFunctionParameter:
207         if (ir_context->get_def_use_mgr()
208                 ->GetDef(instruction.result_type_id())
209                 ->opcode() == SpvOpTypePointer) {
210           transformation_context->GetFactManager()
211               ->AddFactValueOfPointeeIsIrrelevant(instruction.result_id());
212         }
213         break;
214       case SpvOpVariable:
215         transformation_context->GetFactManager()
216             ->AddFactValueOfPointeeIsIrrelevant(instruction.result_id());
217         break;
218       default:
219         break;
220     }
221   }
222 }
223 
ToMessage() const224 protobufs::Transformation TransformationAddFunction::ToMessage() const {
225   protobufs::Transformation result;
226   *result.mutable_add_function() = message_;
227   return result;
228 }
229 
TryToAddFunction(opt::IRContext * ir_context) const230 bool TransformationAddFunction::TryToAddFunction(
231     opt::IRContext* ir_context) const {
232   // This function returns false if |message_.instruction| was not well-formed
233   // enough to actually create a function and add it to |ir_context|.
234 
235   // A function must have at least some instructions.
236   if (message_.instruction().empty()) {
237     return false;
238   }
239 
240   // A function must start with OpFunction.
241   auto function_begin = message_.instruction(0);
242   if (function_begin.opcode() != SpvOpFunction) {
243     return false;
244   }
245 
246   // Make a function, headed by the OpFunction instruction.
247   std::unique_ptr<opt::Function> new_function = MakeUnique<opt::Function>(
248       InstructionFromMessage(ir_context, function_begin));
249 
250   // Keeps track of which instruction protobuf message we are currently
251   // considering.
252   uint32_t instruction_index = 1;
253   const auto num_instructions =
254       static_cast<uint32_t>(message_.instruction().size());
255 
256   // Iterate through all function parameter instructions, adding parameters to
257   // the new function.
258   while (instruction_index < num_instructions &&
259          message_.instruction(instruction_index).opcode() ==
260              SpvOpFunctionParameter) {
261     new_function->AddParameter(InstructionFromMessage(
262         ir_context, message_.instruction(instruction_index)));
263     instruction_index++;
264   }
265 
266   // After the parameters, there needs to be a label.
267   if (instruction_index == num_instructions ||
268       message_.instruction(instruction_index).opcode() != SpvOpLabel) {
269     return false;
270   }
271 
272   // Iterate through the instructions block by block until the end of the
273   // function is reached.
274   while (instruction_index < num_instructions &&
275          message_.instruction(instruction_index).opcode() != SpvOpFunctionEnd) {
276     // Invariant: we should always be at a label instruction at this point.
277     assert(message_.instruction(instruction_index).opcode() == SpvOpLabel);
278 
279     // Make a basic block using the label instruction.
280     std::unique_ptr<opt::BasicBlock> block =
281         MakeUnique<opt::BasicBlock>(InstructionFromMessage(
282             ir_context, message_.instruction(instruction_index)));
283 
284     // Consider successive instructions until we hit another label or the end
285     // of the function, adding each such instruction to the block.
286     instruction_index++;
287     while (instruction_index < num_instructions &&
288            message_.instruction(instruction_index).opcode() !=
289                SpvOpFunctionEnd &&
290            message_.instruction(instruction_index).opcode() != SpvOpLabel) {
291       block->AddInstruction(InstructionFromMessage(
292           ir_context, message_.instruction(instruction_index)));
293       instruction_index++;
294     }
295     // Add the block to the new function.
296     new_function->AddBasicBlock(std::move(block));
297   }
298   // Having considered all the blocks, we should be at the last instruction and
299   // it needs to be OpFunctionEnd.
300   if (instruction_index != num_instructions - 1 ||
301       message_.instruction(instruction_index).opcode() != SpvOpFunctionEnd) {
302     return false;
303   }
304   // Set the function's final instruction, add the function to the module and
305   // report success.
306   new_function->SetFunctionEnd(InstructionFromMessage(
307       ir_context, message_.instruction(instruction_index)));
308   ir_context->AddFunction(std::move(new_function));
309 
310   ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
311 
312   return true;
313 }
314 
TryToMakeFunctionLivesafe(opt::IRContext * ir_context,const TransformationContext & transformation_context) const315 bool TransformationAddFunction::TryToMakeFunctionLivesafe(
316     opt::IRContext* ir_context,
317     const TransformationContext& transformation_context) const {
318   assert(message_.is_livesafe() && "Precondition: is_livesafe must hold.");
319 
320   // Get a pointer to the added function.
321   opt::Function* added_function = nullptr;
322   for (auto& function : *ir_context->module()) {
323     if (function.result_id() == message_.instruction(0).result_id()) {
324       added_function = &function;
325       break;
326     }
327   }
328   assert(added_function && "The added function should have been found.");
329 
330   if (!TryToAddLoopLimiters(ir_context, added_function)) {
331     // Adding loop limiters did not work; bail out.
332     return false;
333   }
334 
335   // Consider all the instructions in the function, and:
336   // - attempt to replace OpKill and OpUnreachable with return instructions
337   // - attempt to clamp access chains to be within bounds
338   // - check that OpFunctionCall instructions are only to livesafe functions
339   for (auto& block : *added_function) {
340     for (auto& inst : block) {
341       switch (inst.opcode()) {
342         case SpvOpKill:
343         case SpvOpUnreachable:
344           if (!TryToTurnKillOrUnreachableIntoReturn(ir_context, added_function,
345                                                     &inst)) {
346             return false;
347           }
348           break;
349         case SpvOpAccessChain:
350         case SpvOpInBoundsAccessChain:
351           if (!TryToClampAccessChainIndices(ir_context, &inst)) {
352             return false;
353           }
354           break;
355         case SpvOpFunctionCall:
356           // A livesafe function my only call other livesafe functions.
357           if (!transformation_context.GetFactManager()->FunctionIsLivesafe(
358                   inst.GetSingleWordInOperand(0))) {
359             return false;
360           }
361         default:
362           break;
363       }
364     }
365   }
366   return true;
367 }
368 
GetBackEdgeBlockId(opt::IRContext * ir_context,uint32_t loop_header_block_id)369 uint32_t TransformationAddFunction::GetBackEdgeBlockId(
370     opt::IRContext* ir_context, uint32_t loop_header_block_id) {
371   const auto* loop_header_block =
372       ir_context->cfg()->block(loop_header_block_id);
373   assert(loop_header_block && "|loop_header_block_id| is invalid");
374 
375   for (auto pred : ir_context->cfg()->preds(loop_header_block_id)) {
376     if (ir_context->GetDominatorAnalysis(loop_header_block->GetParent())
377             ->Dominates(loop_header_block_id, pred)) {
378       return pred;
379     }
380   }
381 
382   return 0;
383 }
384 
TryToAddLoopLimiters(opt::IRContext * ir_context,opt::Function * added_function) const385 bool TransformationAddFunction::TryToAddLoopLimiters(
386     opt::IRContext* ir_context, opt::Function* added_function) const {
387   // Collect up all the loop headers so that we can subsequently add loop
388   // limiting logic.
389   std::vector<opt::BasicBlock*> loop_headers;
390   for (auto& block : *added_function) {
391     if (block.IsLoopHeader()) {
392       loop_headers.push_back(&block);
393     }
394   }
395 
396   if (loop_headers.empty()) {
397     // There are no loops, so no need to add any loop limiters.
398     return true;
399   }
400 
401   // Check that the module contains appropriate ingredients for declaring and
402   // manipulating a loop limiter.
403 
404   auto loop_limit_constant_id_instr =
405       ir_context->get_def_use_mgr()->GetDef(message_.loop_limit_constant_id());
406   if (!loop_limit_constant_id_instr ||
407       loop_limit_constant_id_instr->opcode() != SpvOpConstant) {
408     // The loop limit constant id instruction must exist and have an
409     // appropriate opcode.
410     return false;
411   }
412 
413   auto loop_limit_type = ir_context->get_def_use_mgr()->GetDef(
414       loop_limit_constant_id_instr->type_id());
415   if (loop_limit_type->opcode() != SpvOpTypeInt ||
416       loop_limit_type->GetSingleWordInOperand(0) != 32) {
417     // The type of the loop limit constant must be 32-bit integer.  It
418     // doesn't actually matter whether the integer is signed or not.
419     return false;
420   }
421 
422   // Find the id of the "unsigned int" type.
423   opt::analysis::Integer unsigned_int_type(32, false);
424   uint32_t unsigned_int_type_id =
425       ir_context->get_type_mgr()->GetId(&unsigned_int_type);
426   if (!unsigned_int_type_id) {
427     // Unsigned int is not available; we need this type in order to add loop
428     // limiters.
429     return false;
430   }
431   auto registered_unsigned_int_type =
432       ir_context->get_type_mgr()->GetRegisteredType(&unsigned_int_type);
433 
434   // Look for 0 of type unsigned int.
435   opt::analysis::IntConstant zero(registered_unsigned_int_type->AsInteger(),
436                                   {0});
437   auto registered_zero = ir_context->get_constant_mgr()->FindConstant(&zero);
438   if (!registered_zero) {
439     // We need 0 in order to be able to initialize loop limiters.
440     return false;
441   }
442   uint32_t zero_id = ir_context->get_constant_mgr()
443                          ->GetDefiningInstruction(registered_zero)
444                          ->result_id();
445 
446   // Look for 1 of type unsigned int.
447   opt::analysis::IntConstant one(registered_unsigned_int_type->AsInteger(),
448                                  {1});
449   auto registered_one = ir_context->get_constant_mgr()->FindConstant(&one);
450   if (!registered_one) {
451     // We need 1 in order to be able to increment loop limiters.
452     return false;
453   }
454   uint32_t one_id = ir_context->get_constant_mgr()
455                         ->GetDefiningInstruction(registered_one)
456                         ->result_id();
457 
458   // Look for pointer-to-unsigned int type.
459   opt::analysis::Pointer pointer_to_unsigned_int_type(
460       registered_unsigned_int_type, SpvStorageClassFunction);
461   uint32_t pointer_to_unsigned_int_type_id =
462       ir_context->get_type_mgr()->GetId(&pointer_to_unsigned_int_type);
463   if (!pointer_to_unsigned_int_type_id) {
464     // We need pointer-to-unsigned int in order to declare the loop limiter
465     // variable.
466     return false;
467   }
468 
469   // Look for bool type.
470   opt::analysis::Bool bool_type;
471   uint32_t bool_type_id = ir_context->get_type_mgr()->GetId(&bool_type);
472   if (!bool_type_id) {
473     // We need bool in order to compare the loop limiter's value with the loop
474     // limit constant.
475     return false;
476   }
477 
478   // Declare the loop limiter variable at the start of the function's entry
479   // block, via an instruction of the form:
480   //   %loop_limiter_var = SpvOpVariable %ptr_to_uint Function %zero
481   added_function->begin()->begin()->InsertBefore(MakeUnique<opt::Instruction>(
482       ir_context, SpvOpVariable, pointer_to_unsigned_int_type_id,
483       message_.loop_limiter_variable_id(),
484       opt::Instruction::OperandList(
485           {{SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
486            {SPV_OPERAND_TYPE_ID, {zero_id}}})));
487   // Update the module's id bound since we have added the loop limiter
488   // variable id.
489   fuzzerutil::UpdateModuleIdBound(ir_context,
490                                   message_.loop_limiter_variable_id());
491 
492   // Consider each loop in turn.
493   for (auto loop_header : loop_headers) {
494     // Look for the loop's back-edge block.  This is a predecessor of the loop
495     // header that is dominated by the loop header.
496     const auto back_edge_block_id =
497         GetBackEdgeBlockId(ir_context, loop_header->id());
498     if (!back_edge_block_id) {
499       // The loop's back-edge block must be unreachable.  This means that the
500       // loop cannot iterate, so there is no need to make it lifesafe; we can
501       // move on from this loop.
502       continue;
503     }
504 
505     // If the loop's merge block is unreachable, then there are no constraints
506     // on where the merge block appears in relation to the blocks of the loop.
507     // This means we need to be careful when adding a branch from the back-edge
508     // block to the merge block: the branch might make the loop merge reachable,
509     // and it might then be dominated by the loop header and possibly by other
510     // blocks in the loop. Since a block needs to appear before those blocks it
511     // strictly dominates, this could make the module invalid. To avoid this
512     // problem we bail out in the case where the loop header does not dominate
513     // the loop merge.
514     if (!ir_context->GetDominatorAnalysis(added_function)
515              ->Dominates(loop_header->id(), loop_header->MergeBlockId())) {
516       return false;
517     }
518 
519     // Go through the sequence of loop limiter infos and find the one
520     // corresponding to this loop.
521     bool found = false;
522     protobufs::LoopLimiterInfo loop_limiter_info;
523     for (auto& info : message_.loop_limiter_info()) {
524       if (info.loop_header_id() == loop_header->id()) {
525         loop_limiter_info = info;
526         found = true;
527         break;
528       }
529     }
530     if (!found) {
531       // We don't have loop limiter info for this loop header.
532       return false;
533     }
534 
535     // The back-edge block either has the form:
536     //
537     // (1)
538     //
539     // %l = OpLabel
540     //      ... instructions ...
541     //      OpBranch %loop_header
542     //
543     // (2)
544     //
545     // %l = OpLabel
546     //      ... instructions ...
547     //      OpBranchConditional %c %loop_header %loop_merge
548     //
549     // (3)
550     //
551     // %l = OpLabel
552     //      ... instructions ...
553     //      OpBranchConditional %c %loop_merge %loop_header
554     //
555     // We turn these into the following:
556     //
557     // (1)
558     //
559     //  %l = OpLabel
560     //       ... instructions ...
561     // %t1 = OpLoad %uint32 %loop_limiter
562     // %t2 = OpIAdd %uint32 %t1 %one
563     //       OpStore %loop_limiter %t2
564     // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
565     //       OpBranchConditional %t3 %loop_merge %loop_header
566     //
567     // (2)
568     //
569     //  %l = OpLabel
570     //       ... instructions ...
571     // %t1 = OpLoad %uint32 %loop_limiter
572     // %t2 = OpIAdd %uint32 %t1 %one
573     //       OpStore %loop_limiter %t2
574     // %t3 = OpULessThan %bool %t1 %loop_limit
575     // %t4 = OpLogicalAnd %bool %c %t3
576     //       OpBranchConditional %t4 %loop_header %loop_merge
577     //
578     // (3)
579     //
580     //  %l = OpLabel
581     //       ... instructions ...
582     // %t1 = OpLoad %uint32 %loop_limiter
583     // %t2 = OpIAdd %uint32 %t1 %one
584     //       OpStore %loop_limiter %t2
585     // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
586     // %t4 = OpLogicalOr %bool %c %t3
587     //       OpBranchConditional %t4 %loop_merge %loop_header
588 
589     auto back_edge_block = ir_context->cfg()->block(back_edge_block_id);
590     auto back_edge_block_terminator = back_edge_block->terminator();
591     bool compare_using_greater_than_equal;
592     if (back_edge_block_terminator->opcode() == SpvOpBranch) {
593       compare_using_greater_than_equal = true;
594     } else {
595       assert(back_edge_block_terminator->opcode() == SpvOpBranchConditional);
596       assert(((back_edge_block_terminator->GetSingleWordInOperand(1) ==
597                    loop_header->id() &&
598                back_edge_block_terminator->GetSingleWordInOperand(2) ==
599                    loop_header->MergeBlockId()) ||
600               (back_edge_block_terminator->GetSingleWordInOperand(2) ==
601                    loop_header->id() &&
602                back_edge_block_terminator->GetSingleWordInOperand(1) ==
603                    loop_header->MergeBlockId())) &&
604              "A back edge edge block must branch to"
605              " either the loop header or merge");
606       compare_using_greater_than_equal =
607           back_edge_block_terminator->GetSingleWordInOperand(1) ==
608           loop_header->MergeBlockId();
609     }
610 
611     std::vector<std::unique_ptr<opt::Instruction>> new_instructions;
612 
613     // Add a load from the loop limiter variable, of the form:
614     //   %t1 = OpLoad %uint32 %loop_limiter
615     new_instructions.push_back(MakeUnique<opt::Instruction>(
616         ir_context, SpvOpLoad, unsigned_int_type_id,
617         loop_limiter_info.load_id(),
618         opt::Instruction::OperandList(
619             {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}}})));
620 
621     // Increment the loaded value:
622     //   %t2 = OpIAdd %uint32 %t1 %one
623     new_instructions.push_back(MakeUnique<opt::Instruction>(
624         ir_context, SpvOpIAdd, unsigned_int_type_id,
625         loop_limiter_info.increment_id(),
626         opt::Instruction::OperandList(
627             {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.load_id()}},
628              {SPV_OPERAND_TYPE_ID, {one_id}}})));
629 
630     // Store the incremented value back to the loop limiter variable:
631     //   OpStore %loop_limiter %t2
632     new_instructions.push_back(MakeUnique<opt::Instruction>(
633         ir_context, SpvOpStore, 0, 0,
634         opt::Instruction::OperandList(
635             {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}},
636              {SPV_OPERAND_TYPE_ID, {loop_limiter_info.increment_id()}}})));
637 
638     // Compare the loaded value with the loop limit; either:
639     //   %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
640     // or
641     //   %t3 = OpULessThan %bool %t1 %loop_limit
642     new_instructions.push_back(MakeUnique<opt::Instruction>(
643         ir_context,
644         compare_using_greater_than_equal ? SpvOpUGreaterThanEqual
645                                          : SpvOpULessThan,
646         bool_type_id, loop_limiter_info.compare_id(),
647         opt::Instruction::OperandList(
648             {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.load_id()}},
649              {SPV_OPERAND_TYPE_ID, {message_.loop_limit_constant_id()}}})));
650 
651     if (back_edge_block_terminator->opcode() == SpvOpBranchConditional) {
652       new_instructions.push_back(MakeUnique<opt::Instruction>(
653           ir_context,
654           compare_using_greater_than_equal ? SpvOpLogicalOr : SpvOpLogicalAnd,
655           bool_type_id, loop_limiter_info.logical_op_id(),
656           opt::Instruction::OperandList(
657               {{SPV_OPERAND_TYPE_ID,
658                 {back_edge_block_terminator->GetSingleWordInOperand(0)}},
659                {SPV_OPERAND_TYPE_ID, {loop_limiter_info.compare_id()}}})));
660     }
661 
662     // Add the new instructions at the end of the back edge block, before the
663     // terminator and any loop merge instruction (as the back edge block can
664     // be the loop header).
665     if (back_edge_block->GetLoopMergeInst()) {
666       back_edge_block->GetLoopMergeInst()->InsertBefore(
667           std::move(new_instructions));
668     } else {
669       back_edge_block_terminator->InsertBefore(std::move(new_instructions));
670     }
671 
672     if (back_edge_block_terminator->opcode() == SpvOpBranchConditional) {
673       back_edge_block_terminator->SetInOperand(
674           0, {loop_limiter_info.logical_op_id()});
675     } else {
676       assert(back_edge_block_terminator->opcode() == SpvOpBranch &&
677              "Back-edge terminator must be OpBranch or OpBranchConditional");
678 
679       // Check that, if the merge block starts with OpPhi instructions, suitable
680       // ids have been provided to give these instructions a value corresponding
681       // to the new incoming edge from the back edge block.
682       auto merge_block = ir_context->cfg()->block(loop_header->MergeBlockId());
683       if (!fuzzerutil::PhiIdsOkForNewEdge(ir_context, back_edge_block,
684                                           merge_block,
685                                           loop_limiter_info.phi_id())) {
686         return false;
687       }
688 
689       // Augment OpPhi instructions at the loop merge with the given ids.
690       uint32_t phi_index = 0;
691       for (auto& inst : *merge_block) {
692         if (inst.opcode() != SpvOpPhi) {
693           break;
694         }
695         assert(phi_index <
696                    static_cast<uint32_t>(loop_limiter_info.phi_id().size()) &&
697                "There should be at least one phi id per OpPhi instruction.");
698         inst.AddOperand(
699             {SPV_OPERAND_TYPE_ID, {loop_limiter_info.phi_id(phi_index)}});
700         inst.AddOperand({SPV_OPERAND_TYPE_ID, {back_edge_block_id}});
701         phi_index++;
702       }
703 
704       // Add the new edge, by changing OpBranch to OpBranchConditional.
705       back_edge_block_terminator->SetOpcode(SpvOpBranchConditional);
706       back_edge_block_terminator->SetInOperands(opt::Instruction::OperandList(
707           {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.compare_id()}},
708            {SPV_OPERAND_TYPE_ID, {loop_header->MergeBlockId()}},
709            {SPV_OPERAND_TYPE_ID, {loop_header->id()}}}));
710     }
711 
712     // Update the module's id bound with respect to the various ids that
713     // have been used for loop limiter manipulation.
714     fuzzerutil::UpdateModuleIdBound(ir_context, loop_limiter_info.load_id());
715     fuzzerutil::UpdateModuleIdBound(ir_context,
716                                     loop_limiter_info.increment_id());
717     fuzzerutil::UpdateModuleIdBound(ir_context, loop_limiter_info.compare_id());
718     fuzzerutil::UpdateModuleIdBound(ir_context,
719                                     loop_limiter_info.logical_op_id());
720   }
721   return true;
722 }
723 
TryToTurnKillOrUnreachableIntoReturn(opt::IRContext * ir_context,opt::Function * added_function,opt::Instruction * kill_or_unreachable_inst) const724 bool TransformationAddFunction::TryToTurnKillOrUnreachableIntoReturn(
725     opt::IRContext* ir_context, opt::Function* added_function,
726     opt::Instruction* kill_or_unreachable_inst) const {
727   assert((kill_or_unreachable_inst->opcode() == SpvOpKill ||
728           kill_or_unreachable_inst->opcode() == SpvOpUnreachable) &&
729          "Precondition: instruction must be OpKill or OpUnreachable.");
730 
731   // Get the function's return type.
732   auto function_return_type_inst =
733       ir_context->get_def_use_mgr()->GetDef(added_function->type_id());
734 
735   if (function_return_type_inst->opcode() == SpvOpTypeVoid) {
736     // The function has void return type, so change this instruction to
737     // OpReturn.
738     kill_or_unreachable_inst->SetOpcode(SpvOpReturn);
739   } else {
740     // The function has non-void return type, so change this instruction
741     // to OpReturnValue, using the value id provided with the
742     // transformation.
743 
744     // We first check that the id, %id, provided with the transformation
745     // specifically to turn OpKill and OpUnreachable instructions into
746     // OpReturnValue %id has the same type as the function's return type.
747     if (ir_context->get_def_use_mgr()
748             ->GetDef(message_.kill_unreachable_return_value_id())
749             ->type_id() != function_return_type_inst->result_id()) {
750       return false;
751     }
752     kill_or_unreachable_inst->SetOpcode(SpvOpReturnValue);
753     kill_or_unreachable_inst->SetInOperands(
754         {{SPV_OPERAND_TYPE_ID, {message_.kill_unreachable_return_value_id()}}});
755   }
756   return true;
757 }
758 
TryToClampAccessChainIndices(opt::IRContext * ir_context,opt::Instruction * access_chain_inst) const759 bool TransformationAddFunction::TryToClampAccessChainIndices(
760     opt::IRContext* ir_context, opt::Instruction* access_chain_inst) const {
761   assert((access_chain_inst->opcode() == SpvOpAccessChain ||
762           access_chain_inst->opcode() == SpvOpInBoundsAccessChain) &&
763          "Precondition: instruction must be OpAccessChain or "
764          "OpInBoundsAccessChain.");
765 
766   // Find the AccessChainClampingInfo associated with this access chain.
767   const protobufs::AccessChainClampingInfo* access_chain_clamping_info =
768       nullptr;
769   for (auto& clamping_info : message_.access_chain_clamping_info()) {
770     if (clamping_info.access_chain_id() == access_chain_inst->result_id()) {
771       access_chain_clamping_info = &clamping_info;
772       break;
773     }
774   }
775   if (!access_chain_clamping_info) {
776     // No access chain clamping information was found; the function cannot be
777     // made livesafe.
778     return false;
779   }
780 
781   // Check that there is a (compare_id, select_id) pair for every
782   // index associated with the instruction.
783   if (static_cast<uint32_t>(
784           access_chain_clamping_info->compare_and_select_ids().size()) !=
785       access_chain_inst->NumInOperands() - 1) {
786     return false;
787   }
788 
789   // Walk the access chain, clamping each index to be within bounds if it is
790   // not a constant.
791   auto base_object = ir_context->get_def_use_mgr()->GetDef(
792       access_chain_inst->GetSingleWordInOperand(0));
793   assert(base_object && "The base object must exist.");
794   auto pointer_type =
795       ir_context->get_def_use_mgr()->GetDef(base_object->type_id());
796   assert(pointer_type && pointer_type->opcode() == SpvOpTypePointer &&
797          "The base object must have pointer type.");
798   auto should_be_composite_type = ir_context->get_def_use_mgr()->GetDef(
799       pointer_type->GetSingleWordInOperand(1));
800 
801   // Consider each index input operand in turn (operand 0 is the base object).
802   for (uint32_t index = 1; index < access_chain_inst->NumInOperands();
803        index++) {
804     // We are going to turn:
805     //
806     // %result = OpAccessChain %type %object ... %index ...
807     //
808     // into:
809     //
810     // %t1 = OpULessThanEqual %bool %index %bound_minus_one
811     // %t2 = OpSelect %int_type %t1 %index %bound_minus_one
812     // %result = OpAccessChain %type %object ... %t2 ...
813     //
814     // ... unless %index is already a constant.
815 
816     // Get the bound for the composite being indexed into; e.g. the number of
817     // columns of matrix or the size of an array.
818     uint32_t bound = fuzzerutil::GetBoundForCompositeIndex(
819         *should_be_composite_type, ir_context);
820 
821     // Get the instruction associated with the index and figure out its integer
822     // type.
823     const uint32_t index_id = access_chain_inst->GetSingleWordInOperand(index);
824     auto index_inst = ir_context->get_def_use_mgr()->GetDef(index_id);
825     auto index_type_inst =
826         ir_context->get_def_use_mgr()->GetDef(index_inst->type_id());
827     assert(index_type_inst->opcode() == SpvOpTypeInt);
828     assert(index_type_inst->GetSingleWordInOperand(0) == 32);
829     opt::analysis::Integer* index_int_type =
830         ir_context->get_type_mgr()
831             ->GetType(index_type_inst->result_id())
832             ->AsInteger();
833 
834     if (index_inst->opcode() != SpvOpConstant ||
835         index_inst->GetSingleWordInOperand(0) >= bound) {
836       // The index is either non-constant or an out-of-bounds constant, so we
837       // need to clamp it.
838       assert(should_be_composite_type->opcode() != SpvOpTypeStruct &&
839              "Access chain indices into structures are required to be "
840              "constants.");
841       opt::analysis::IntConstant bound_minus_one(index_int_type, {bound - 1});
842       if (!ir_context->get_constant_mgr()->FindConstant(&bound_minus_one)) {
843         // We do not have an integer constant whose value is |bound| -1.
844         return false;
845       }
846 
847       opt::analysis::Bool bool_type;
848       uint32_t bool_type_id = ir_context->get_type_mgr()->GetId(&bool_type);
849       if (!bool_type_id) {
850         // Bool type is not declared; we cannot do a comparison.
851         return false;
852       }
853 
854       uint32_t bound_minus_one_id =
855           ir_context->get_constant_mgr()
856               ->GetDefiningInstruction(&bound_minus_one)
857               ->result_id();
858 
859       uint32_t compare_id =
860           access_chain_clamping_info->compare_and_select_ids(index - 1).first();
861       uint32_t select_id =
862           access_chain_clamping_info->compare_and_select_ids(index - 1)
863               .second();
864       std::vector<std::unique_ptr<opt::Instruction>> new_instructions;
865 
866       // Compare the index with the bound via an instruction of the form:
867       //   %t1 = OpULessThanEqual %bool %index %bound_minus_one
868       new_instructions.push_back(MakeUnique<opt::Instruction>(
869           ir_context, SpvOpULessThanEqual, bool_type_id, compare_id,
870           opt::Instruction::OperandList(
871               {{SPV_OPERAND_TYPE_ID, {index_inst->result_id()}},
872                {SPV_OPERAND_TYPE_ID, {bound_minus_one_id}}})));
873 
874       // Select the index if in-bounds, otherwise one less than the bound:
875       //   %t2 = OpSelect %int_type %t1 %index %bound_minus_one
876       new_instructions.push_back(MakeUnique<opt::Instruction>(
877           ir_context, SpvOpSelect, index_type_inst->result_id(), select_id,
878           opt::Instruction::OperandList(
879               {{SPV_OPERAND_TYPE_ID, {compare_id}},
880                {SPV_OPERAND_TYPE_ID, {index_inst->result_id()}},
881                {SPV_OPERAND_TYPE_ID, {bound_minus_one_id}}})));
882 
883       // Add the new instructions before the access chain
884       access_chain_inst->InsertBefore(std::move(new_instructions));
885 
886       // Replace %index with %t2.
887       access_chain_inst->SetInOperand(index, {select_id});
888       fuzzerutil::UpdateModuleIdBound(ir_context, compare_id);
889       fuzzerutil::UpdateModuleIdBound(ir_context, select_id);
890     }
891     should_be_composite_type =
892         FollowCompositeIndex(ir_context, *should_be_composite_type, index_id);
893   }
894   return true;
895 }
896 
FollowCompositeIndex(opt::IRContext * ir_context,const opt::Instruction & composite_type_inst,uint32_t index_id)897 opt::Instruction* TransformationAddFunction::FollowCompositeIndex(
898     opt::IRContext* ir_context, const opt::Instruction& composite_type_inst,
899     uint32_t index_id) {
900   uint32_t sub_object_type_id;
901   switch (composite_type_inst.opcode()) {
902     case SpvOpTypeArray:
903     case SpvOpTypeRuntimeArray:
904       sub_object_type_id = composite_type_inst.GetSingleWordInOperand(0);
905       break;
906     case SpvOpTypeMatrix:
907     case SpvOpTypeVector:
908       sub_object_type_id = composite_type_inst.GetSingleWordInOperand(0);
909       break;
910     case SpvOpTypeStruct: {
911       auto index_inst = ir_context->get_def_use_mgr()->GetDef(index_id);
912       assert(index_inst->opcode() == SpvOpConstant);
913       assert(ir_context->get_def_use_mgr()
914                  ->GetDef(index_inst->type_id())
915                  ->opcode() == SpvOpTypeInt);
916       assert(ir_context->get_def_use_mgr()
917                  ->GetDef(index_inst->type_id())
918                  ->GetSingleWordInOperand(0) == 32);
919       uint32_t index_value = index_inst->GetSingleWordInOperand(0);
920       sub_object_type_id =
921           composite_type_inst.GetSingleWordInOperand(index_value);
922       break;
923     }
924     default:
925       assert(false && "Unknown composite type.");
926       sub_object_type_id = 0;
927       break;
928   }
929   assert(sub_object_type_id && "No sub-object found.");
930   return ir_context->get_def_use_mgr()->GetDef(sub_object_type_id);
931 }
932 
GetFreshIds() const933 std::unordered_set<uint32_t> TransformationAddFunction::GetFreshIds() const {
934   std::unordered_set<uint32_t> result;
935   for (auto& instruction : message_.instruction()) {
936     result.insert(instruction.result_id());
937   }
938   if (message_.is_livesafe()) {
939     result.insert(message_.loop_limiter_variable_id());
940     for (auto& loop_limiter_info : message_.loop_limiter_info()) {
941       result.insert(loop_limiter_info.load_id());
942       result.insert(loop_limiter_info.increment_id());
943       result.insert(loop_limiter_info.compare_id());
944       result.insert(loop_limiter_info.logical_op_id());
945     }
946     for (auto& access_chain_clamping_info :
947          message_.access_chain_clamping_info()) {
948       for (auto& pair : access_chain_clamping_info.compare_and_select_ids()) {
949         result.insert(pair.first());
950         result.insert(pair.second());
951       }
952     }
953   }
954   return result;
955 }
956 
957 }  // namespace fuzz
958 }  // namespace spvtools
959