• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/fuzz/transformation_add_function.h"
16 
17 #include "source/fuzz/fuzzer_util.h"
18 #include "source/fuzz/instruction_message.h"
19 
20 namespace spvtools {
21 namespace fuzz {
22 
TransformationAddFunction(const spvtools::fuzz::protobufs::TransformationAddFunction & message)23 TransformationAddFunction::TransformationAddFunction(
24     const spvtools::fuzz::protobufs::TransformationAddFunction& message)
25     : message_(message) {}
26 
TransformationAddFunction(const std::vector<protobufs::Instruction> & instructions)27 TransformationAddFunction::TransformationAddFunction(
28     const std::vector<protobufs::Instruction>& instructions) {
29   for (auto& instruction : instructions) {
30     *message_.add_instruction() = instruction;
31   }
32   message_.set_is_livesafe(false);
33 }
34 
TransformationAddFunction(const std::vector<protobufs::Instruction> & instructions,uint32_t loop_limiter_variable_id,uint32_t loop_limit_constant_id,const std::vector<protobufs::LoopLimiterInfo> & loop_limiters,uint32_t kill_unreachable_return_value_id,const std::vector<protobufs::AccessChainClampingInfo> & access_chain_clampers)35 TransformationAddFunction::TransformationAddFunction(
36     const std::vector<protobufs::Instruction>& instructions,
37     uint32_t loop_limiter_variable_id, uint32_t loop_limit_constant_id,
38     const std::vector<protobufs::LoopLimiterInfo>& loop_limiters,
39     uint32_t kill_unreachable_return_value_id,
40     const std::vector<protobufs::AccessChainClampingInfo>&
41         access_chain_clampers) {
42   for (auto& instruction : instructions) {
43     *message_.add_instruction() = instruction;
44   }
45   message_.set_is_livesafe(true);
46   message_.set_loop_limiter_variable_id(loop_limiter_variable_id);
47   message_.set_loop_limit_constant_id(loop_limit_constant_id);
48   for (auto& loop_limiter : loop_limiters) {
49     *message_.add_loop_limiter_info() = loop_limiter;
50   }
51   message_.set_kill_unreachable_return_value_id(
52       kill_unreachable_return_value_id);
53   for (auto& access_clamper : access_chain_clampers) {
54     *message_.add_access_chain_clamping_info() = access_clamper;
55   }
56 }
57 
IsApplicable(opt::IRContext * context,const spvtools::fuzz::FactManager & fact_manager) const58 bool TransformationAddFunction::IsApplicable(
59     opt::IRContext* context,
60     const spvtools::fuzz::FactManager& fact_manager) const {
61   // This transformation may use a lot of ids, all of which need to be fresh
62   // and distinct.  This set tracks them.
63   std::set<uint32_t> ids_used_by_this_transformation;
64 
65   // Ensure that all result ids in the new function are fresh and distinct.
66   for (auto& instruction : message_.instruction()) {
67     if (instruction.result_id()) {
68       if (!CheckIdIsFreshAndNotUsedByThisTransformation(
69               instruction.result_id(), context,
70               &ids_used_by_this_transformation)) {
71         return false;
72       }
73     }
74   }
75 
76   if (message_.is_livesafe()) {
77     // Ensure that all ids provided for making the function livesafe are fresh
78     // and distinct.
79     if (!CheckIdIsFreshAndNotUsedByThisTransformation(
80             message_.loop_limiter_variable_id(), context,
81             &ids_used_by_this_transformation)) {
82       return false;
83     }
84     for (auto& loop_limiter_info : message_.loop_limiter_info()) {
85       if (!CheckIdIsFreshAndNotUsedByThisTransformation(
86               loop_limiter_info.load_id(), context,
87               &ids_used_by_this_transformation)) {
88         return false;
89       }
90       if (!CheckIdIsFreshAndNotUsedByThisTransformation(
91               loop_limiter_info.increment_id(), context,
92               &ids_used_by_this_transformation)) {
93         return false;
94       }
95       if (!CheckIdIsFreshAndNotUsedByThisTransformation(
96               loop_limiter_info.compare_id(), context,
97               &ids_used_by_this_transformation)) {
98         return false;
99       }
100       if (!CheckIdIsFreshAndNotUsedByThisTransformation(
101               loop_limiter_info.logical_op_id(), context,
102               &ids_used_by_this_transformation)) {
103         return false;
104       }
105     }
106     for (auto& access_chain_clamping_info :
107          message_.access_chain_clamping_info()) {
108       for (auto& pair : access_chain_clamping_info.compare_and_select_ids()) {
109         if (!CheckIdIsFreshAndNotUsedByThisTransformation(
110                 pair.first(), context, &ids_used_by_this_transformation)) {
111           return false;
112         }
113         if (!CheckIdIsFreshAndNotUsedByThisTransformation(
114                 pair.second(), context, &ids_used_by_this_transformation)) {
115           return false;
116         }
117       }
118     }
119   }
120 
121   // Because checking all the conditions for a function to be valid is a big
122   // job that the SPIR-V validator can already do, a "try it and see" approach
123   // is taken here.
124 
125   // We first clone the current module, so that we can try adding the new
126   // function without risking wrecking |context|.
127   auto cloned_module = fuzzerutil::CloneIRContext(context);
128 
129   // We try to add a function to the cloned module, which may fail if
130   // |message_.instruction| is not sufficiently well-formed.
131   if (!TryToAddFunction(cloned_module.get())) {
132     return false;
133   }
134 
135   // Check whether the cloned module is still valid after adding the function.
136   // If it is not, the transformation is not applicable.
137   if (!fuzzerutil::IsValid(cloned_module.get())) {
138     return false;
139   }
140 
141   if (message_.is_livesafe()) {
142     if (!TryToMakeFunctionLivesafe(cloned_module.get(), fact_manager)) {
143       return false;
144     }
145     // After making the function livesafe, we check validity of the module
146     // again.  This is because the turning of OpKill, OpUnreachable and OpReturn
147     // instructions into branches changes control flow graph reachability, which
148     // has the potential to make the module invalid when it was otherwise valid.
149     // It is simpler to rely on the validator to guard against this than to
150     // consider all scenarios when making a function livesafe.
151     if (!fuzzerutil::IsValid(cloned_module.get())) {
152       return false;
153     }
154   }
155   return true;
156 }
157 
Apply(opt::IRContext * context,spvtools::fuzz::FactManager * fact_manager) const158 void TransformationAddFunction::Apply(
159     opt::IRContext* context, spvtools::fuzz::FactManager* fact_manager) const {
160   // Add the function to the module.  As the transformation is applicable, this
161   // should succeed.
162   bool success = TryToAddFunction(context);
163   assert(success && "The function should be successfully added.");
164   (void)(success);  // Keep release builds happy (otherwise they may complain
165                     // that |success| is not used).
166 
167   // Record the fact that all pointer parameters and variables declared in the
168   // function should be regarded as having irrelevant values.  This allows other
169   // passes to store arbitrarily to such variables, and to pass them freely as
170   // parameters to other functions knowing that it is OK if they get
171   // over-written.
172   for (auto& instruction : message_.instruction()) {
173     switch (instruction.opcode()) {
174       case SpvOpFunctionParameter:
175         if (context->get_def_use_mgr()
176                 ->GetDef(instruction.result_type_id())
177                 ->opcode() == SpvOpTypePointer) {
178           fact_manager->AddFactValueOfPointeeIsIrrelevant(
179               instruction.result_id());
180         }
181         break;
182       case SpvOpVariable:
183         fact_manager->AddFactValueOfPointeeIsIrrelevant(
184             instruction.result_id());
185         break;
186       default:
187         break;
188     }
189   }
190 
191   if (message_.is_livesafe()) {
192     // Make the function livesafe, which also should succeed.
193     success = TryToMakeFunctionLivesafe(context, *fact_manager);
194     assert(success && "It should be possible to make the function livesafe.");
195     (void)(success);  // Keep release builds happy.
196 
197     // Inform the fact manager that the function is livesafe.
198     assert(message_.instruction(0).opcode() == SpvOpFunction &&
199            "The first instruction of an 'add function' transformation must be "
200            "OpFunction.");
201     fact_manager->AddFactFunctionIsLivesafe(
202         message_.instruction(0).result_id());
203   } else {
204     // Inform the fact manager that all blocks in the function are dead.
205     for (auto& inst : message_.instruction()) {
206       if (inst.opcode() == SpvOpLabel) {
207         fact_manager->AddFactBlockIsDead(inst.result_id());
208       }
209     }
210   }
211   context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
212 }
213 
ToMessage() const214 protobufs::Transformation TransformationAddFunction::ToMessage() const {
215   protobufs::Transformation result;
216   *result.mutable_add_function() = message_;
217   return result;
218 }
219 
TryToAddFunction(opt::IRContext * context) const220 bool TransformationAddFunction::TryToAddFunction(
221     opt::IRContext* context) const {
222   // This function returns false if |message_.instruction| was not well-formed
223   // enough to actually create a function and add it to |context|.
224 
225   // A function must have at least some instructions.
226   if (message_.instruction().empty()) {
227     return false;
228   }
229 
230   // A function must start with OpFunction.
231   auto function_begin = message_.instruction(0);
232   if (function_begin.opcode() != SpvOpFunction) {
233     return false;
234   }
235 
236   // Make a function, headed by the OpFunction instruction.
237   std::unique_ptr<opt::Function> new_function = MakeUnique<opt::Function>(
238       InstructionFromMessage(context, function_begin));
239 
240   // Keeps track of which instruction protobuf message we are currently
241   // considering.
242   uint32_t instruction_index = 1;
243   const auto num_instructions =
244       static_cast<uint32_t>(message_.instruction().size());
245 
246   // Iterate through all function parameter instructions, adding parameters to
247   // the new function.
248   while (instruction_index < num_instructions &&
249          message_.instruction(instruction_index).opcode() ==
250              SpvOpFunctionParameter) {
251     new_function->AddParameter(InstructionFromMessage(
252         context, message_.instruction(instruction_index)));
253     instruction_index++;
254   }
255 
256   // After the parameters, there needs to be a label.
257   if (instruction_index == num_instructions ||
258       message_.instruction(instruction_index).opcode() != SpvOpLabel) {
259     return false;
260   }
261 
262   // Iterate through the instructions block by block until the end of the
263   // function is reached.
264   while (instruction_index < num_instructions &&
265          message_.instruction(instruction_index).opcode() != SpvOpFunctionEnd) {
266     // Invariant: we should always be at a label instruction at this point.
267     assert(message_.instruction(instruction_index).opcode() == SpvOpLabel);
268 
269     // Make a basic block using the label instruction, with the new function
270     // as its parent.
271     std::unique_ptr<opt::BasicBlock> block =
272         MakeUnique<opt::BasicBlock>(InstructionFromMessage(
273             context, message_.instruction(instruction_index)));
274     block->SetParent(new_function.get());
275 
276     // Consider successive instructions until we hit another label or the end
277     // of the function, adding each such instruction to the block.
278     instruction_index++;
279     while (instruction_index < num_instructions &&
280            message_.instruction(instruction_index).opcode() !=
281                SpvOpFunctionEnd &&
282            message_.instruction(instruction_index).opcode() != SpvOpLabel) {
283       block->AddInstruction(InstructionFromMessage(
284           context, message_.instruction(instruction_index)));
285       instruction_index++;
286     }
287     // Add the block to the new function.
288     new_function->AddBasicBlock(std::move(block));
289   }
290   // Having considered all the blocks, we should be at the last instruction and
291   // it needs to be OpFunctionEnd.
292   if (instruction_index != num_instructions - 1 ||
293       message_.instruction(instruction_index).opcode() != SpvOpFunctionEnd) {
294     return false;
295   }
296   // Set the function's final instruction, add the function to the module and
297   // report success.
298   new_function->SetFunctionEnd(
299       InstructionFromMessage(context, message_.instruction(instruction_index)));
300   context->AddFunction(std::move(new_function));
301 
302   context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
303 
304   return true;
305 }
306 
TryToMakeFunctionLivesafe(opt::IRContext * context,const FactManager & fact_manager) const307 bool TransformationAddFunction::TryToMakeFunctionLivesafe(
308     opt::IRContext* context, const FactManager& fact_manager) const {
309   assert(message_.is_livesafe() && "Precondition: is_livesafe must hold.");
310 
311   // Get a pointer to the added function.
312   opt::Function* added_function = nullptr;
313   for (auto& function : *context->module()) {
314     if (function.result_id() == message_.instruction(0).result_id()) {
315       added_function = &function;
316       break;
317     }
318   }
319   assert(added_function && "The added function should have been found.");
320 
321   if (!TryToAddLoopLimiters(context, added_function)) {
322     // Adding loop limiters did not work; bail out.
323     return false;
324   }
325 
326   // Consider all the instructions in the function, and:
327   // - attempt to replace OpKill and OpUnreachable with return instructions
328   // - attempt to clamp access chains to be within bounds
329   // - check that OpFunctionCall instructions are only to livesafe functions
330   for (auto& block : *added_function) {
331     for (auto& inst : block) {
332       switch (inst.opcode()) {
333         case SpvOpKill:
334         case SpvOpUnreachable:
335           if (!TryToTurnKillOrUnreachableIntoReturn(context, added_function,
336                                                     &inst)) {
337             return false;
338           }
339           break;
340         case SpvOpAccessChain:
341         case SpvOpInBoundsAccessChain:
342           if (!TryToClampAccessChainIndices(context, &inst)) {
343             return false;
344           }
345           break;
346         case SpvOpFunctionCall:
347           // A livesafe function my only call other livesafe functions.
348           if (!fact_manager.FunctionIsLivesafe(
349                   inst.GetSingleWordInOperand(0))) {
350             return false;
351           }
352         default:
353           break;
354       }
355     }
356   }
357   return true;
358 }
359 
TryToAddLoopLimiters(opt::IRContext * context,opt::Function * added_function) const360 bool TransformationAddFunction::TryToAddLoopLimiters(
361     opt::IRContext* context, opt::Function* added_function) const {
362   // Collect up all the loop headers so that we can subsequently add loop
363   // limiting logic.
364   std::vector<opt::BasicBlock*> loop_headers;
365   for (auto& block : *added_function) {
366     if (block.IsLoopHeader()) {
367       loop_headers.push_back(&block);
368     }
369   }
370 
371   if (loop_headers.empty()) {
372     // There are no loops, so no need to add any loop limiters.
373     return true;
374   }
375 
376   // Check that the module contains appropriate ingredients for declaring and
377   // manipulating a loop limiter.
378 
379   auto loop_limit_constant_id_instr =
380       context->get_def_use_mgr()->GetDef(message_.loop_limit_constant_id());
381   if (!loop_limit_constant_id_instr ||
382       loop_limit_constant_id_instr->opcode() != SpvOpConstant) {
383     // The loop limit constant id instruction must exist and have an
384     // appropriate opcode.
385     return false;
386   }
387 
388   auto loop_limit_type = context->get_def_use_mgr()->GetDef(
389       loop_limit_constant_id_instr->type_id());
390   if (loop_limit_type->opcode() != SpvOpTypeInt ||
391       loop_limit_type->GetSingleWordInOperand(0) != 32) {
392     // The type of the loop limit constant must be 32-bit integer.  It
393     // doesn't actually matter whether the integer is signed or not.
394     return false;
395   }
396 
397   // Find the id of the "unsigned int" type.
398   opt::analysis::Integer unsigned_int_type(32, false);
399   uint32_t unsigned_int_type_id =
400       context->get_type_mgr()->GetId(&unsigned_int_type);
401   if (!unsigned_int_type_id) {
402     // Unsigned int is not available; we need this type in order to add loop
403     // limiters.
404     return false;
405   }
406   auto registered_unsigned_int_type =
407       context->get_type_mgr()->GetRegisteredType(&unsigned_int_type);
408 
409   // Look for 0 of type unsigned int.
410   opt::analysis::IntConstant zero(registered_unsigned_int_type->AsInteger(),
411                                   {0});
412   auto registered_zero = context->get_constant_mgr()->FindConstant(&zero);
413   if (!registered_zero) {
414     // We need 0 in order to be able to initialize loop limiters.
415     return false;
416   }
417   uint32_t zero_id = context->get_constant_mgr()
418                          ->GetDefiningInstruction(registered_zero)
419                          ->result_id();
420 
421   // Look for 1 of type unsigned int.
422   opt::analysis::IntConstant one(registered_unsigned_int_type->AsInteger(),
423                                  {1});
424   auto registered_one = context->get_constant_mgr()->FindConstant(&one);
425   if (!registered_one) {
426     // We need 1 in order to be able to increment loop limiters.
427     return false;
428   }
429   uint32_t one_id = context->get_constant_mgr()
430                         ->GetDefiningInstruction(registered_one)
431                         ->result_id();
432 
433   // Look for pointer-to-unsigned int type.
434   opt::analysis::Pointer pointer_to_unsigned_int_type(
435       registered_unsigned_int_type, SpvStorageClassFunction);
436   uint32_t pointer_to_unsigned_int_type_id =
437       context->get_type_mgr()->GetId(&pointer_to_unsigned_int_type);
438   if (!pointer_to_unsigned_int_type_id) {
439     // We need pointer-to-unsigned int in order to declare the loop limiter
440     // variable.
441     return false;
442   }
443 
444   // Look for bool type.
445   opt::analysis::Bool bool_type;
446   uint32_t bool_type_id = context->get_type_mgr()->GetId(&bool_type);
447   if (!bool_type_id) {
448     // We need bool in order to compare the loop limiter's value with the loop
449     // limit constant.
450     return false;
451   }
452 
453   // Declare the loop limiter variable at the start of the function's entry
454   // block, via an instruction of the form:
455   //   %loop_limiter_var = SpvOpVariable %ptr_to_uint Function %zero
456   added_function->begin()->begin()->InsertBefore(MakeUnique<opt::Instruction>(
457       context, SpvOpVariable, pointer_to_unsigned_int_type_id,
458       message_.loop_limiter_variable_id(),
459       opt::Instruction::OperandList(
460           {{SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
461            {SPV_OPERAND_TYPE_ID, {zero_id}}})));
462   // Update the module's id bound since we have added the loop limiter
463   // variable id.
464   fuzzerutil::UpdateModuleIdBound(context, message_.loop_limiter_variable_id());
465 
466   // Consider each loop in turn.
467   for (auto loop_header : loop_headers) {
468     // Look for the loop's back-edge block.  This is a predecessor of the loop
469     // header that is dominated by the loop header.
470     uint32_t back_edge_block_id = 0;
471     for (auto pred : context->cfg()->preds(loop_header->id())) {
472       if (context->GetDominatorAnalysis(added_function)
473               ->Dominates(loop_header->id(), pred)) {
474         back_edge_block_id = pred;
475         break;
476       }
477     }
478     if (!back_edge_block_id) {
479       // The loop's back-edge block must be unreachable.  This means that the
480       // loop cannot iterate, so there is no need to make it lifesafe; we can
481       // move on from this loop.
482       continue;
483     }
484     auto back_edge_block = context->cfg()->block(back_edge_block_id);
485 
486     // Go through the sequence of loop limiter infos and find the one
487     // corresponding to this loop.
488     bool found = false;
489     protobufs::LoopLimiterInfo loop_limiter_info;
490     for (auto& info : message_.loop_limiter_info()) {
491       if (info.loop_header_id() == loop_header->id()) {
492         loop_limiter_info = info;
493         found = true;
494         break;
495       }
496     }
497     if (!found) {
498       // We don't have loop limiter info for this loop header.
499       return false;
500     }
501 
502     // The back-edge block either has the form:
503     //
504     // (1)
505     //
506     // %l = OpLabel
507     //      ... instructions ...
508     //      OpBranch %loop_header
509     //
510     // (2)
511     //
512     // %l = OpLabel
513     //      ... instructions ...
514     //      OpBranchConditional %c %loop_header %loop_merge
515     //
516     // (3)
517     //
518     // %l = OpLabel
519     //      ... instructions ...
520     //      OpBranchConditional %c %loop_merge %loop_header
521     //
522     // We turn these into the following:
523     //
524     // (1)
525     //
526     //  %l = OpLabel
527     //       ... instructions ...
528     // %t1 = OpLoad %uint32 %loop_limiter
529     // %t2 = OpIAdd %uint32 %t1 %one
530     //       OpStore %loop_limiter %t2
531     // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
532     //       OpBranchConditional %t3 %loop_merge %loop_header
533     //
534     // (2)
535     //
536     //  %l = OpLabel
537     //       ... instructions ...
538     // %t1 = OpLoad %uint32 %loop_limiter
539     // %t2 = OpIAdd %uint32 %t1 %one
540     //       OpStore %loop_limiter %t2
541     // %t3 = OpULessThan %bool %t1 %loop_limit
542     // %t4 = OpLogicalAnd %bool %c %t3
543     //       OpBranchConditional %t4 %loop_header %loop_merge
544     //
545     // (3)
546     //
547     //  %l = OpLabel
548     //       ... instructions ...
549     // %t1 = OpLoad %uint32 %loop_limiter
550     // %t2 = OpIAdd %uint32 %t1 %one
551     //       OpStore %loop_limiter %t2
552     // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
553     // %t4 = OpLogicalOr %bool %c %t3
554     //       OpBranchConditional %t4 %loop_merge %loop_header
555 
556     auto back_edge_block_terminator = back_edge_block->terminator();
557     bool compare_using_greater_than_equal;
558     if (back_edge_block_terminator->opcode() == SpvOpBranch) {
559       compare_using_greater_than_equal = true;
560     } else {
561       assert(back_edge_block_terminator->opcode() == SpvOpBranchConditional);
562       assert(((back_edge_block_terminator->GetSingleWordInOperand(1) ==
563                    loop_header->id() &&
564                back_edge_block_terminator->GetSingleWordInOperand(2) ==
565                    loop_header->MergeBlockId()) ||
566               (back_edge_block_terminator->GetSingleWordInOperand(2) ==
567                    loop_header->id() &&
568                back_edge_block_terminator->GetSingleWordInOperand(1) ==
569                    loop_header->MergeBlockId())) &&
570              "A back edge edge block must branch to"
571              " either the loop header or merge");
572       compare_using_greater_than_equal =
573           back_edge_block_terminator->GetSingleWordInOperand(1) ==
574           loop_header->MergeBlockId();
575     }
576 
577     std::vector<std::unique_ptr<opt::Instruction>> new_instructions;
578 
579     // Add a load from the loop limiter variable, of the form:
580     //   %t1 = OpLoad %uint32 %loop_limiter
581     new_instructions.push_back(MakeUnique<opt::Instruction>(
582         context, SpvOpLoad, unsigned_int_type_id, loop_limiter_info.load_id(),
583         opt::Instruction::OperandList(
584             {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}}})));
585 
586     // Increment the loaded value:
587     //   %t2 = OpIAdd %uint32 %t1 %one
588     new_instructions.push_back(MakeUnique<opt::Instruction>(
589         context, SpvOpIAdd, unsigned_int_type_id,
590         loop_limiter_info.increment_id(),
591         opt::Instruction::OperandList(
592             {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.load_id()}},
593              {SPV_OPERAND_TYPE_ID, {one_id}}})));
594 
595     // Store the incremented value back to the loop limiter variable:
596     //   OpStore %loop_limiter %t2
597     new_instructions.push_back(MakeUnique<opt::Instruction>(
598         context, SpvOpStore, 0, 0,
599         opt::Instruction::OperandList(
600             {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}},
601              {SPV_OPERAND_TYPE_ID, {loop_limiter_info.increment_id()}}})));
602 
603     // Compare the loaded value with the loop limit; either:
604     //   %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
605     // or
606     //   %t3 = OpULessThan %bool %t1 %loop_limit
607     new_instructions.push_back(MakeUnique<opt::Instruction>(
608         context,
609         compare_using_greater_than_equal ? SpvOpUGreaterThanEqual
610                                          : SpvOpULessThan,
611         bool_type_id, loop_limiter_info.compare_id(),
612         opt::Instruction::OperandList(
613             {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.load_id()}},
614              {SPV_OPERAND_TYPE_ID, {message_.loop_limit_constant_id()}}})));
615 
616     if (back_edge_block_terminator->opcode() == SpvOpBranchConditional) {
617       new_instructions.push_back(MakeUnique<opt::Instruction>(
618           context,
619           compare_using_greater_than_equal ? SpvOpLogicalOr : SpvOpLogicalAnd,
620           bool_type_id, loop_limiter_info.logical_op_id(),
621           opt::Instruction::OperandList(
622               {{SPV_OPERAND_TYPE_ID,
623                 {back_edge_block_terminator->GetSingleWordInOperand(0)}},
624                {SPV_OPERAND_TYPE_ID, {loop_limiter_info.compare_id()}}})));
625     }
626 
627     // Add the new instructions at the end of the back edge block, before the
628     // terminator and any loop merge instruction (as the back edge block can
629     // be the loop header).
630     if (back_edge_block->GetLoopMergeInst()) {
631       back_edge_block->GetLoopMergeInst()->InsertBefore(
632           std::move(new_instructions));
633     } else {
634       back_edge_block_terminator->InsertBefore(std::move(new_instructions));
635     }
636 
637     if (back_edge_block_terminator->opcode() == SpvOpBranchConditional) {
638       back_edge_block_terminator->SetInOperand(
639           0, {loop_limiter_info.logical_op_id()});
640     } else {
641       assert(back_edge_block_terminator->opcode() == SpvOpBranch &&
642              "Back-edge terminator must be OpBranch or OpBranchConditional");
643 
644       // Check that, if the merge block starts with OpPhi instructions, suitable
645       // ids have been provided to give these instructions a value corresponding
646       // to the new incoming edge from the back edge block.
647       auto merge_block = context->cfg()->block(loop_header->MergeBlockId());
648       if (!fuzzerutil::PhiIdsOkForNewEdge(context, back_edge_block, merge_block,
649                                           loop_limiter_info.phi_id())) {
650         return false;
651       }
652 
653       // Augment OpPhi instructions at the loop merge with the given ids.
654       uint32_t phi_index = 0;
655       for (auto& inst : *merge_block) {
656         if (inst.opcode() != SpvOpPhi) {
657           break;
658         }
659         assert(phi_index <
660                    static_cast<uint32_t>(loop_limiter_info.phi_id().size()) &&
661                "There should be at least one phi id per OpPhi instruction.");
662         inst.AddOperand(
663             {SPV_OPERAND_TYPE_ID, {loop_limiter_info.phi_id(phi_index)}});
664         inst.AddOperand({SPV_OPERAND_TYPE_ID, {back_edge_block_id}});
665         phi_index++;
666       }
667 
668       // Add the new edge, by changing OpBranch to OpBranchConditional.
669       // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3162): This
670       //  could be a problem if the merge block was originally unreachable: it
671       //  might now be dominated by other blocks that it appears earlier than in
672       //  the module.
673       back_edge_block_terminator->SetOpcode(SpvOpBranchConditional);
674       back_edge_block_terminator->SetInOperands(opt::Instruction::OperandList(
675           {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.compare_id()}},
676            {SPV_OPERAND_TYPE_ID, {loop_header->MergeBlockId()}
677 
678            },
679            {SPV_OPERAND_TYPE_ID, {loop_header->id()}}}));
680     }
681 
682     // Update the module's id bound with respect to the various ids that
683     // have been used for loop limiter manipulation.
684     fuzzerutil::UpdateModuleIdBound(context, loop_limiter_info.load_id());
685     fuzzerutil::UpdateModuleIdBound(context, loop_limiter_info.increment_id());
686     fuzzerutil::UpdateModuleIdBound(context, loop_limiter_info.compare_id());
687     fuzzerutil::UpdateModuleIdBound(context, loop_limiter_info.logical_op_id());
688   }
689   return true;
690 }
691 
TryToTurnKillOrUnreachableIntoReturn(opt::IRContext * context,opt::Function * added_function,opt::Instruction * kill_or_unreachable_inst) const692 bool TransformationAddFunction::TryToTurnKillOrUnreachableIntoReturn(
693     opt::IRContext* context, opt::Function* added_function,
694     opt::Instruction* kill_or_unreachable_inst) const {
695   assert((kill_or_unreachable_inst->opcode() == SpvOpKill ||
696           kill_or_unreachable_inst->opcode() == SpvOpUnreachable) &&
697          "Precondition: instruction must be OpKill or OpUnreachable.");
698 
699   // Get the function's return type.
700   auto function_return_type_inst =
701       context->get_def_use_mgr()->GetDef(added_function->type_id());
702 
703   if (function_return_type_inst->opcode() == SpvOpTypeVoid) {
704     // The function has void return type, so change this instruction to
705     // OpReturn.
706     kill_or_unreachable_inst->SetOpcode(SpvOpReturn);
707   } else {
708     // The function has non-void return type, so change this instruction
709     // to OpReturnValue, using the value id provided with the
710     // transformation.
711 
712     // We first check that the id, %id, provided with the transformation
713     // specifically to turn OpKill and OpUnreachable instructions into
714     // OpReturnValue %id has the same type as the function's return type.
715     if (context->get_def_use_mgr()
716             ->GetDef(message_.kill_unreachable_return_value_id())
717             ->type_id() != function_return_type_inst->result_id()) {
718       return false;
719     }
720     kill_or_unreachable_inst->SetOpcode(SpvOpReturnValue);
721     kill_or_unreachable_inst->SetInOperands(
722         {{SPV_OPERAND_TYPE_ID, {message_.kill_unreachable_return_value_id()}}});
723   }
724   return true;
725 }
726 
TryToClampAccessChainIndices(opt::IRContext * context,opt::Instruction * access_chain_inst) const727 bool TransformationAddFunction::TryToClampAccessChainIndices(
728     opt::IRContext* context, opt::Instruction* access_chain_inst) const {
729   assert((access_chain_inst->opcode() == SpvOpAccessChain ||
730           access_chain_inst->opcode() == SpvOpInBoundsAccessChain) &&
731          "Precondition: instruction must be OpAccessChain or "
732          "OpInBoundsAccessChain.");
733 
734   // Find the AccessChainClampingInfo associated with this access chain.
735   const protobufs::AccessChainClampingInfo* access_chain_clamping_info =
736       nullptr;
737   for (auto& clamping_info : message_.access_chain_clamping_info()) {
738     if (clamping_info.access_chain_id() == access_chain_inst->result_id()) {
739       access_chain_clamping_info = &clamping_info;
740       break;
741     }
742   }
743   if (!access_chain_clamping_info) {
744     // No access chain clamping information was found; the function cannot be
745     // made livesafe.
746     return false;
747   }
748 
749   // Check that there is a (compare_id, select_id) pair for every
750   // index associated with the instruction.
751   if (static_cast<uint32_t>(
752           access_chain_clamping_info->compare_and_select_ids().size()) !=
753       access_chain_inst->NumInOperands() - 1) {
754     return false;
755   }
756 
757   // Walk the access chain, clamping each index to be within bounds if it is
758   // not a constant.
759   auto base_object = context->get_def_use_mgr()->GetDef(
760       access_chain_inst->GetSingleWordInOperand(0));
761   assert(base_object && "The base object must exist.");
762   auto pointer_type =
763       context->get_def_use_mgr()->GetDef(base_object->type_id());
764   assert(pointer_type && pointer_type->opcode() == SpvOpTypePointer &&
765          "The base object must have pointer type.");
766   auto should_be_composite_type = context->get_def_use_mgr()->GetDef(
767       pointer_type->GetSingleWordInOperand(1));
768 
769   // Consider each index input operand in turn (operand 0 is the base object).
770   for (uint32_t index = 1; index < access_chain_inst->NumInOperands();
771        index++) {
772     // We are going to turn:
773     //
774     // %result = OpAccessChain %type %object ... %index ...
775     //
776     // into:
777     //
778     // %t1 = OpULessThanEqual %bool %index %bound_minus_one
779     // %t2 = OpSelect %int_type %t1 %index %bound_minus_one
780     // %result = OpAccessChain %type %object ... %t2 ...
781     //
782     // ... unless %index is already a constant.
783 
784     // Get the bound for the composite being indexed into; e.g. the number of
785     // columns of matrix or the size of an array.
786     uint32_t bound =
787         GetBoundForCompositeIndex(context, *should_be_composite_type);
788 
789     // Get the instruction associated with the index and figure out its integer
790     // type.
791     const uint32_t index_id = access_chain_inst->GetSingleWordInOperand(index);
792     auto index_inst = context->get_def_use_mgr()->GetDef(index_id);
793     auto index_type_inst =
794         context->get_def_use_mgr()->GetDef(index_inst->type_id());
795     assert(index_type_inst->opcode() == SpvOpTypeInt);
796     assert(index_type_inst->GetSingleWordInOperand(0) == 32);
797     opt::analysis::Integer* index_int_type =
798         context->get_type_mgr()
799             ->GetType(index_type_inst->result_id())
800             ->AsInteger();
801 
802     if (index_inst->opcode() != SpvOpConstant) {
803       // The index is non-constant so we need to clamp it.
804       assert(should_be_composite_type->opcode() != SpvOpTypeStruct &&
805              "Access chain indices into structures are required to be "
806              "constants.");
807       opt::analysis::IntConstant bound_minus_one(index_int_type, {bound - 1});
808       if (!context->get_constant_mgr()->FindConstant(&bound_minus_one)) {
809         // We do not have an integer constant whose value is |bound| -1.
810         return false;
811       }
812 
813       opt::analysis::Bool bool_type;
814       uint32_t bool_type_id = context->get_type_mgr()->GetId(&bool_type);
815       if (!bool_type_id) {
816         // Bool type is not declared; we cannot do a comparison.
817         return false;
818       }
819 
820       uint32_t bound_minus_one_id =
821           context->get_constant_mgr()
822               ->GetDefiningInstruction(&bound_minus_one)
823               ->result_id();
824 
825       uint32_t compare_id =
826           access_chain_clamping_info->compare_and_select_ids(index - 1).first();
827       uint32_t select_id =
828           access_chain_clamping_info->compare_and_select_ids(index - 1)
829               .second();
830       std::vector<std::unique_ptr<opt::Instruction>> new_instructions;
831 
832       // Compare the index with the bound via an instruction of the form:
833       //   %t1 = OpULessThanEqual %bool %index %bound_minus_one
834       new_instructions.push_back(MakeUnique<opt::Instruction>(
835           context, SpvOpULessThanEqual, bool_type_id, compare_id,
836           opt::Instruction::OperandList(
837               {{SPV_OPERAND_TYPE_ID, {index_inst->result_id()}},
838                {SPV_OPERAND_TYPE_ID, {bound_minus_one_id}}})));
839 
840       // Select the index if in-bounds, otherwise one less than the bound:
841       //   %t2 = OpSelect %int_type %t1 %index %bound_minus_one
842       new_instructions.push_back(MakeUnique<opt::Instruction>(
843           context, SpvOpSelect, index_type_inst->result_id(), select_id,
844           opt::Instruction::OperandList(
845               {{SPV_OPERAND_TYPE_ID, {compare_id}},
846                {SPV_OPERAND_TYPE_ID, {index_inst->result_id()}},
847                {SPV_OPERAND_TYPE_ID, {bound_minus_one_id}}})));
848 
849       // Add the new instructions before the access chain
850       access_chain_inst->InsertBefore(std::move(new_instructions));
851 
852       // Replace %index with %t2.
853       access_chain_inst->SetInOperand(index, {select_id});
854       fuzzerutil::UpdateModuleIdBound(context, compare_id);
855       fuzzerutil::UpdateModuleIdBound(context, select_id);
856     } else {
857       // TODO(afd): At present the SPIR-V spec is not clear on whether
858       //  statically out-of-bounds indices mean that a module is invalid (so
859       //  that it should be rejected by the validator), or that such accesses
860       //  yield undefined results.  Via the following assertion, we assume that
861       //  functions added to the module do not feature statically out-of-bounds
862       //  accesses.
863       // Assert that the index is smaller (unsigned) than this value.
864       // Return false if it is not (to keep compilers happy).
865       if (index_inst->GetSingleWordInOperand(0) >= bound) {
866         assert(false &&
867                "The function has a statically out-of-bounds access; "
868                "this should not occur.");
869         return false;
870       }
871     }
872     should_be_composite_type =
873         FollowCompositeIndex(context, *should_be_composite_type, index_id);
874   }
875   return true;
876 }
877 
GetBoundForCompositeIndex(opt::IRContext * context,const opt::Instruction & composite_type_inst)878 uint32_t TransformationAddFunction::GetBoundForCompositeIndex(
879     opt::IRContext* context, const opt::Instruction& composite_type_inst) {
880   switch (composite_type_inst.opcode()) {
881     case SpvOpTypeArray:
882       return fuzzerutil::GetArraySize(composite_type_inst, context);
883     case SpvOpTypeMatrix:
884     case SpvOpTypeVector:
885       return composite_type_inst.GetSingleWordInOperand(1);
886     case SpvOpTypeStruct: {
887       return fuzzerutil::GetNumberOfStructMembers(composite_type_inst);
888     }
889     default:
890       assert(false && "Unknown composite type.");
891       return 0;
892   }
893 }
894 
FollowCompositeIndex(opt::IRContext * context,const opt::Instruction & composite_type_inst,uint32_t index_id)895 opt::Instruction* TransformationAddFunction::FollowCompositeIndex(
896     opt::IRContext* context, const opt::Instruction& composite_type_inst,
897     uint32_t index_id) {
898   uint32_t sub_object_type_id;
899   switch (composite_type_inst.opcode()) {
900     case SpvOpTypeArray:
901       sub_object_type_id = composite_type_inst.GetSingleWordInOperand(0);
902       break;
903     case SpvOpTypeMatrix:
904     case SpvOpTypeVector:
905       sub_object_type_id = composite_type_inst.GetSingleWordInOperand(0);
906       break;
907     case SpvOpTypeStruct: {
908       auto index_inst = context->get_def_use_mgr()->GetDef(index_id);
909       assert(index_inst->opcode() == SpvOpConstant);
910       assert(
911           context->get_def_use_mgr()->GetDef(index_inst->type_id())->opcode() ==
912           SpvOpTypeInt);
913       assert(context->get_def_use_mgr()
914                  ->GetDef(index_inst->type_id())
915                  ->GetSingleWordInOperand(0) == 32);
916       uint32_t index_value = index_inst->GetSingleWordInOperand(0);
917       sub_object_type_id =
918           composite_type_inst.GetSingleWordInOperand(index_value);
919       break;
920     }
921     default:
922       assert(false && "Unknown composite type.");
923       sub_object_type_id = 0;
924       break;
925   }
926   assert(sub_object_type_id && "No sub-object found.");
927   return context->get_def_use_mgr()->GetDef(sub_object_type_id);
928 }
929 
930 }  // namespace fuzz
931 }  // namespace spvtools
932