• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/fuzz/force_render_red.h"
16 
17 #include "source/fuzz/fact_manager.h"
18 #include "source/fuzz/instruction_descriptor.h"
19 #include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
20 #include "source/fuzz/transformation_replace_constant_with_uniform.h"
21 #include "source/fuzz/uniform_buffer_element_descriptor.h"
22 #include "source/opt/build_module.h"
23 #include "source/opt/ir_context.h"
24 #include "source/opt/types.h"
25 #include "source/util/make_unique.h"
26 #include "tools/util/cli_consumer.h"
27 
#include <algorithm>
#include <cstring>
#include <utility>
30 
31 namespace spvtools {
32 namespace fuzz {
33 
34 namespace {
35 
36 // Helper method to find the fragment shader entry point, complaining if there
37 // is no shader or if there is no fragment entry point.
FindFragmentShaderEntryPoint(opt::IRContext * ir_context,MessageConsumer message_consumer)38 opt::Function* FindFragmentShaderEntryPoint(opt::IRContext* ir_context,
39                                             MessageConsumer message_consumer) {
40   // Check that this is a fragment shader
41   bool found_capability_shader = false;
42   for (auto& capability : ir_context->capabilities()) {
43     assert(capability.opcode() == SpvOpCapability);
44     if (capability.GetSingleWordInOperand(0) == SpvCapabilityShader) {
45       found_capability_shader = true;
46       break;
47     }
48   }
49   if (!found_capability_shader) {
50     message_consumer(
51         SPV_MSG_ERROR, nullptr, {},
52         "Forcing of red rendering requires the Shader capability.");
53     return nullptr;
54   }
55 
56   opt::Instruction* fragment_entry_point = nullptr;
57   for (auto& entry_point : ir_context->module()->entry_points()) {
58     if (entry_point.GetSingleWordInOperand(0) == SpvExecutionModelFragment) {
59       fragment_entry_point = &entry_point;
60       break;
61     }
62   }
63   if (fragment_entry_point == nullptr) {
64     message_consumer(SPV_MSG_ERROR, nullptr, {},
65                      "Forcing of red rendering requires an entry point with "
66                      "the Fragment execution model.");
67     return nullptr;
68   }
69 
70   for (auto& function : *ir_context->module()) {
71     if (function.result_id() ==
72         fragment_entry_point->GetSingleWordInOperand(1)) {
73       return &function;
74     }
75   }
76   assert(
77       false &&
78       "A valid module must have a function associate with each entry point.");
79   return nullptr;
80 }
81 
82 // Helper method to check that there is a single vec4 output variable and get a
83 // pointer to it.
FindVec4OutputVariable(opt::IRContext * ir_context,MessageConsumer message_consumer)84 opt::Instruction* FindVec4OutputVariable(opt::IRContext* ir_context,
85                                          MessageConsumer message_consumer) {
86   opt::Instruction* output_variable = nullptr;
87   for (auto& inst : ir_context->types_values()) {
88     if (inst.opcode() == SpvOpVariable &&
89         inst.GetSingleWordInOperand(0) == SpvStorageClassOutput) {
90       if (output_variable != nullptr) {
91         message_consumer(SPV_MSG_ERROR, nullptr, {},
92                          "Only one output variable can be handled at present; "
93                          "found multiple.");
94         return nullptr;
95       }
96       output_variable = &inst;
97       // Do not break, as we want to check for multiple output variables.
98     }
99   }
100   if (output_variable == nullptr) {
101     message_consumer(SPV_MSG_ERROR, nullptr, {},
102                      "No output variable to which to write red was found.");
103     return nullptr;
104   }
105 
106   auto output_variable_base_type = ir_context->get_type_mgr()
107                                        ->GetType(output_variable->type_id())
108                                        ->AsPointer()
109                                        ->pointee_type()
110                                        ->AsVector();
111   if (!output_variable_base_type ||
112       output_variable_base_type->element_count() != 4 ||
113       !output_variable_base_type->element_type()->AsFloat()) {
114     message_consumer(SPV_MSG_ERROR, nullptr, {},
115                      "The output variable must have type vec4.");
116     return nullptr;
117   }
118 
119   return output_variable;
120 }
121 
122 // Helper to get the ids of float constants 0.0 and 1.0, creating them if
123 // necessary.
FindOrCreateFloatZeroAndOne(opt::IRContext * ir_context,opt::analysis::Float * float_type)124 std::pair<uint32_t, uint32_t> FindOrCreateFloatZeroAndOne(
125     opt::IRContext* ir_context, opt::analysis::Float* float_type) {
126   float one = 1.0;
127   uint32_t one_as_uint;
128   memcpy(&one_as_uint, &one, sizeof(float));
129   std::vector<uint32_t> zero_bytes = {0};
130   std::vector<uint32_t> one_bytes = {one_as_uint};
131   auto constant_zero = ir_context->get_constant_mgr()->RegisterConstant(
132       MakeUnique<opt::analysis::FloatConstant>(float_type, zero_bytes));
133   auto constant_one = ir_context->get_constant_mgr()->RegisterConstant(
134       MakeUnique<opt::analysis::FloatConstant>(float_type, one_bytes));
135   auto constant_zero_id = ir_context->get_constant_mgr()
136                               ->GetDefiningInstruction(constant_zero)
137                               ->result_id();
138   auto constant_one_id = ir_context->get_constant_mgr()
139                              ->GetDefiningInstruction(constant_one)
140                              ->result_id();
141   return std::pair<uint32_t, uint32_t>(constant_zero_id, constant_one_id);
142 }
143 
144 std::unique_ptr<TransformationReplaceConstantWithUniform>
MakeConstantUniformReplacement(opt::IRContext * ir_context,const FactManager & fact_manager,uint32_t constant_id,uint32_t greater_than_instruction,uint32_t in_operand_index)145 MakeConstantUniformReplacement(opt::IRContext* ir_context,
146                                const FactManager& fact_manager,
147                                uint32_t constant_id,
148                                uint32_t greater_than_instruction,
149                                uint32_t in_operand_index) {
150   return MakeUnique<TransformationReplaceConstantWithUniform>(
151       MakeIdUseDescriptor(constant_id,
152                           MakeInstructionDescriptor(greater_than_instruction,
153                                                     SpvOpFOrdGreaterThan, 0),
154                           in_operand_index),
155       fact_manager.GetUniformDescriptorsForConstant(ir_context, constant_id)[0],
156       ir_context->TakeNextId(), ir_context->TakeNextId());
157 }
158 
159 }  // namespace
160 
ForceRenderRed(const spv_target_env & target_env,const std::vector<uint32_t> & binary_in,const spvtools::fuzz::protobufs::FactSequence & initial_facts,std::vector<uint32_t> * binary_out)161 bool ForceRenderRed(
162     const spv_target_env& target_env, const std::vector<uint32_t>& binary_in,
163     const spvtools::fuzz::protobufs::FactSequence& initial_facts,
164     std::vector<uint32_t>* binary_out) {
165   auto message_consumer = spvtools::utils::CLIMessageConsumer;
166   spvtools::SpirvTools tools(target_env);
167   if (!tools.IsValid()) {
168     message_consumer(SPV_MSG_ERROR, nullptr, {},
169                      "Failed to create SPIRV-Tools interface; stopping.");
170     return false;
171   }
172 
173   // Initial binary should be valid.
174   if (!tools.Validate(&binary_in[0], binary_in.size())) {
175     message_consumer(SPV_MSG_ERROR, nullptr, {},
176                      "Initial binary is invalid; stopping.");
177     return false;
178   }
179 
180   // Build the module from the input binary.
181   std::unique_ptr<opt::IRContext> ir_context = BuildModule(
182       target_env, message_consumer, binary_in.data(), binary_in.size());
183   assert(ir_context);
184 
185   // Set up a fact manager with any given initial facts.
186   FactManager fact_manager;
187   for (auto& fact : initial_facts.fact()) {
188     fact_manager.AddFact(fact, ir_context.get());
189   }
190 
191   auto entry_point_function =
192       FindFragmentShaderEntryPoint(ir_context.get(), message_consumer);
193   auto output_variable =
194       FindVec4OutputVariable(ir_context.get(), message_consumer);
195   if (entry_point_function == nullptr || output_variable == nullptr) {
196     return false;
197   }
198 
199   opt::analysis::Float temp_float_type(32);
200   opt::analysis::Float* float_type = ir_context->get_type_mgr()
201                                          ->GetRegisteredType(&temp_float_type)
202                                          ->AsFloat();
203   std::pair<uint32_t, uint32_t> zero_one_float_ids =
204       FindOrCreateFloatZeroAndOne(ir_context.get(), float_type);
205 
206   // Make the new exit block
207   auto new_exit_block_id = ir_context->TakeNextId();
208   {
209     auto label = MakeUnique<opt::Instruction>(ir_context.get(), SpvOpLabel, 0,
210                                               new_exit_block_id,
211                                               opt::Instruction::OperandList());
212     auto new_exit_block = MakeUnique<opt::BasicBlock>(std::move(label));
213     new_exit_block->AddInstruction(MakeUnique<opt::Instruction>(
214         ir_context.get(), SpvOpReturn, 0, 0, opt::Instruction::OperandList()));
215     entry_point_function->AddBasicBlock(std::move(new_exit_block));
216   }
217 
218   // Make the new entry block
219   {
220     auto label = MakeUnique<opt::Instruction>(ir_context.get(), SpvOpLabel, 0,
221                                               ir_context->TakeNextId(),
222                                               opt::Instruction::OperandList());
223     auto new_entry_block = MakeUnique<opt::BasicBlock>(std::move(label));
224 
225     // Make an instruction to construct vec4(1.0, 0.0, 0.0, 1.0), representing
226     // the colour red.
227     opt::Operand zero_float = {SPV_OPERAND_TYPE_ID, {zero_one_float_ids.first}};
228     opt::Operand one_float = {SPV_OPERAND_TYPE_ID, {zero_one_float_ids.second}};
229     opt::Instruction::OperandList op_composite_construct_operands = {
230         one_float, zero_float, zero_float, one_float};
231     auto temp_vec4 = opt::analysis::Vector(float_type, 4);
232     auto vec4_id = ir_context->get_type_mgr()->GetId(&temp_vec4);
233     auto red = MakeUnique<opt::Instruction>(
234         ir_context.get(), SpvOpCompositeConstruct, vec4_id,
235         ir_context->TakeNextId(), op_composite_construct_operands);
236     auto red_id = red->result_id();
237     new_entry_block->AddInstruction(std::move(red));
238 
239     // Make an instruction to store red into the output color.
240     opt::Operand variable_to_store_into = {SPV_OPERAND_TYPE_ID,
241                                            {output_variable->result_id()}};
242     opt::Operand value_to_be_stored = {SPV_OPERAND_TYPE_ID, {red_id}};
243     opt::Instruction::OperandList op_store_operands = {variable_to_store_into,
244                                                        value_to_be_stored};
245     new_entry_block->AddInstruction(MakeUnique<opt::Instruction>(
246         ir_context.get(), SpvOpStore, 0, 0, op_store_operands));
247 
248     // We are going to attempt to construct 'false' as an expression of the form
249     // 'literal1 > literal2'. If we succeed, we will later replace each literal
250     // with a uniform of the same value - we can only do that replacement once
251     // we have added the entry block to the module.
252     std::unique_ptr<TransformationReplaceConstantWithUniform>
253         first_greater_then_operand_replacement = nullptr;
254     std::unique_ptr<TransformationReplaceConstantWithUniform>
255         second_greater_then_operand_replacement = nullptr;
256     uint32_t id_guaranteed_to_be_false = 0;
257 
258     opt::analysis::Bool temp_bool_type;
259     opt::analysis::Bool* registered_bool_type =
260         ir_context->get_type_mgr()
261             ->GetRegisteredType(&temp_bool_type)
262             ->AsBool();
263 
264     auto float_type_id = ir_context->get_type_mgr()->GetId(float_type);
265     auto types_for_which_uniforms_are_known =
266         fact_manager.GetTypesForWhichUniformValuesAreKnown();
267 
268     // Check whether we have any float uniforms.
269     if (std::find(types_for_which_uniforms_are_known.begin(),
270                   types_for_which_uniforms_are_known.end(),
271                   float_type_id) != types_for_which_uniforms_are_known.end()) {
272       // We have at least one float uniform; let's see whether we have at least
273       // two.
274       auto available_constants =
275           fact_manager.GetConstantsAvailableFromUniformsForType(
276               ir_context.get(), float_type_id);
277       if (available_constants.size() > 1) {
278         // Grab the float constants associated with the first two known float
279         // uniforms.
280         auto first_constant =
281             ir_context->get_constant_mgr()
282                 ->GetConstantFromInst(ir_context->get_def_use_mgr()->GetDef(
283                     available_constants[0]))
284                 ->AsFloatConstant();
285         auto second_constant =
286             ir_context->get_constant_mgr()
287                 ->GetConstantFromInst(ir_context->get_def_use_mgr()->GetDef(
288                     available_constants[1]))
289                 ->AsFloatConstant();
290 
291         // Now work out which of the two constants is larger than the other.
292         uint32_t larger_constant_index = 0;
293         uint32_t smaller_constant_index = 0;
294         if (first_constant->GetFloat() > second_constant->GetFloat()) {
295           larger_constant_index = 0;
296           smaller_constant_index = 1;
297         } else if (first_constant->GetFloat() < second_constant->GetFloat()) {
298           larger_constant_index = 1;
299           smaller_constant_index = 0;
300         }
301 
302         // Only proceed with these constants if they have turned out to be
303         // distinct.
304         if (larger_constant_index != smaller_constant_index) {
305           // We are in a position to create 'false' as 'literal1 > literal2', so
306           // reserve an id for this computation; this id will end up being
307           // guaranteed to be 'false'.
308           id_guaranteed_to_be_false = ir_context->TakeNextId();
309 
310           auto smaller_constant = available_constants[smaller_constant_index];
311           auto larger_constant = available_constants[larger_constant_index];
312 
313           opt::Instruction::OperandList greater_than_operands = {
314               {SPV_OPERAND_TYPE_ID, {smaller_constant}},
315               {SPV_OPERAND_TYPE_ID, {larger_constant}}};
316           new_entry_block->AddInstruction(MakeUnique<opt::Instruction>(
317               ir_context.get(), SpvOpFOrdGreaterThan,
318               ir_context->get_type_mgr()->GetId(registered_bool_type),
319               id_guaranteed_to_be_false, greater_than_operands));
320 
321           first_greater_then_operand_replacement =
322               MakeConstantUniformReplacement(ir_context.get(), fact_manager,
323                                              smaller_constant,
324                                              id_guaranteed_to_be_false, 0);
325           second_greater_then_operand_replacement =
326               MakeConstantUniformReplacement(ir_context.get(), fact_manager,
327                                              larger_constant,
328                                              id_guaranteed_to_be_false, 1);
329         }
330       }
331     }
332 
333     if (id_guaranteed_to_be_false == 0) {
334       auto constant_false = ir_context->get_constant_mgr()->RegisterConstant(
335           MakeUnique<opt::analysis::BoolConstant>(registered_bool_type, false));
336       id_guaranteed_to_be_false = ir_context->get_constant_mgr()
337                                       ->GetDefiningInstruction(constant_false)
338                                       ->result_id();
339     }
340 
341     opt::Operand false_condition = {SPV_OPERAND_TYPE_ID,
342                                     {id_guaranteed_to_be_false}};
343     opt::Operand then_block = {SPV_OPERAND_TYPE_ID,
344                                {entry_point_function->entry()->id()}};
345     opt::Operand else_block = {SPV_OPERAND_TYPE_ID, {new_exit_block_id}};
346     opt::Instruction::OperandList op_branch_conditional_operands = {
347         false_condition, then_block, else_block};
348     new_entry_block->AddInstruction(
349         MakeUnique<opt::Instruction>(ir_context.get(), SpvOpBranchConditional,
350                                      0, 0, op_branch_conditional_operands));
351 
352     entry_point_function->InsertBasicBlockBefore(
353         std::move(new_entry_block), entry_point_function->entry().get());
354 
355     for (auto& replacement : {first_greater_then_operand_replacement.get(),
356                               second_greater_then_operand_replacement.get()}) {
357       if (replacement) {
358         assert(replacement->IsApplicable(ir_context.get(), fact_manager));
359         replacement->Apply(ir_context.get(), &fact_manager);
360       }
361     }
362   }
363 
364   // Write out the module as a binary.
365   ir_context->module()->ToBinary(binary_out, false);
366   return true;
367 }
368 
369 }  // namespace fuzz
370 }  // namespace spvtools
371