• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/fuzz/force_render_red.h"
16 
17 #include "source/fuzz/fact_manager/fact_manager.h"
18 #include "source/fuzz/instruction_descriptor.h"
19 #include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
20 #include "source/fuzz/transformation_context.h"
21 #include "source/fuzz/transformation_replace_constant_with_uniform.h"
22 #include "source/opt/build_module.h"
23 #include "source/opt/ir_context.h"
24 #include "source/opt/types.h"
25 #include "source/util/make_unique.h"
26 
27 namespace spvtools {
28 namespace fuzz {
29 
30 namespace {
31 
32 // Helper method to find the fragment shader entry point, complaining if there
33 // is no shader or if there is no fragment entry point.
FindFragmentShaderEntryPoint(opt::IRContext * ir_context,MessageConsumer message_consumer)34 opt::Function* FindFragmentShaderEntryPoint(opt::IRContext* ir_context,
35                                             MessageConsumer message_consumer) {
36   // Check that this is a fragment shader
37   bool found_capability_shader = false;
38   for (auto& capability : ir_context->capabilities()) {
39     assert(capability.opcode() == spv::Op::OpCapability);
40     if (spv::Capability(capability.GetSingleWordInOperand(0)) ==
41         spv::Capability::Shader) {
42       found_capability_shader = true;
43       break;
44     }
45   }
46   if (!found_capability_shader) {
47     message_consumer(
48         SPV_MSG_ERROR, nullptr, {},
49         "Forcing of red rendering requires the Shader capability.");
50     return nullptr;
51   }
52 
53   opt::Instruction* fragment_entry_point = nullptr;
54   for (auto& entry_point : ir_context->module()->entry_points()) {
55     if (spv::ExecutionModel(entry_point.GetSingleWordInOperand(0)) ==
56         spv::ExecutionModel::Fragment) {
57       fragment_entry_point = &entry_point;
58       break;
59     }
60   }
61   if (fragment_entry_point == nullptr) {
62     message_consumer(SPV_MSG_ERROR, nullptr, {},
63                      "Forcing of red rendering requires an entry point with "
64                      "the Fragment execution model.");
65     return nullptr;
66   }
67 
68   for (auto& function : *ir_context->module()) {
69     if (function.result_id() ==
70         fragment_entry_point->GetSingleWordInOperand(1)) {
71       return &function;
72     }
73   }
74   assert(
75       false &&
76       "A valid module must have a function associate with each entry point.");
77   return nullptr;
78 }
79 
80 // Helper method to check that there is a single vec4 output variable and get a
81 // pointer to it.
FindVec4OutputVariable(opt::IRContext * ir_context,MessageConsumer message_consumer)82 opt::Instruction* FindVec4OutputVariable(opt::IRContext* ir_context,
83                                          MessageConsumer message_consumer) {
84   opt::Instruction* output_variable = nullptr;
85   for (auto& inst : ir_context->types_values()) {
86     if (inst.opcode() == spv::Op::OpVariable &&
87         spv::StorageClass(inst.GetSingleWordInOperand(0)) ==
88             spv::StorageClass::Output) {
89       if (output_variable != nullptr) {
90         message_consumer(SPV_MSG_ERROR, nullptr, {},
91                          "Only one output variable can be handled at present; "
92                          "found multiple.");
93         return nullptr;
94       }
95       output_variable = &inst;
96       // Do not break, as we want to check for multiple output variables.
97     }
98   }
99   if (output_variable == nullptr) {
100     message_consumer(SPV_MSG_ERROR, nullptr, {},
101                      "No output variable to which to write red was found.");
102     return nullptr;
103   }
104 
105   auto output_variable_base_type = ir_context->get_type_mgr()
106                                        ->GetType(output_variable->type_id())
107                                        ->AsPointer()
108                                        ->pointee_type()
109                                        ->AsVector();
110   if (!output_variable_base_type ||
111       output_variable_base_type->element_count() != 4 ||
112       !output_variable_base_type->element_type()->AsFloat()) {
113     message_consumer(SPV_MSG_ERROR, nullptr, {},
114                      "The output variable must have type vec4.");
115     return nullptr;
116   }
117 
118   return output_variable;
119 }
120 
121 // Helper to get the ids of float constants 0.0 and 1.0, creating them if
122 // necessary.
FindOrCreateFloatZeroAndOne(opt::IRContext * ir_context,opt::analysis::Float * float_type)123 std::pair<uint32_t, uint32_t> FindOrCreateFloatZeroAndOne(
124     opt::IRContext* ir_context, opt::analysis::Float* float_type) {
125   float one = 1.0;
126   uint32_t one_as_uint;
127   memcpy(&one_as_uint, &one, sizeof(float));
128   std::vector<uint32_t> zero_bytes = {0};
129   std::vector<uint32_t> one_bytes = {one_as_uint};
130   auto constant_zero = ir_context->get_constant_mgr()->RegisterConstant(
131       MakeUnique<opt::analysis::FloatConstant>(float_type, zero_bytes));
132   auto constant_one = ir_context->get_constant_mgr()->RegisterConstant(
133       MakeUnique<opt::analysis::FloatConstant>(float_type, one_bytes));
134   auto constant_zero_id = ir_context->get_constant_mgr()
135                               ->GetDefiningInstruction(constant_zero)
136                               ->result_id();
137   auto constant_one_id = ir_context->get_constant_mgr()
138                              ->GetDefiningInstruction(constant_one)
139                              ->result_id();
140   return std::pair<uint32_t, uint32_t>(constant_zero_id, constant_one_id);
141 }
142 
143 std::unique_ptr<TransformationReplaceConstantWithUniform>
MakeConstantUniformReplacement(opt::IRContext * ir_context,const FactManager & fact_manager,uint32_t constant_id,uint32_t greater_than_instruction,uint32_t in_operand_index)144 MakeConstantUniformReplacement(opt::IRContext* ir_context,
145                                const FactManager& fact_manager,
146                                uint32_t constant_id,
147                                uint32_t greater_than_instruction,
148                                uint32_t in_operand_index) {
149   return MakeUnique<TransformationReplaceConstantWithUniform>(
150       MakeIdUseDescriptor(
151           constant_id,
152           MakeInstructionDescriptor(greater_than_instruction,
153                                     spv::Op::OpFOrdGreaterThan, 0),
154           in_operand_index),
155       fact_manager.GetUniformDescriptorsForConstant(constant_id)[0],
156       ir_context->TakeNextId(), ir_context->TakeNextId());
157 }
158 
159 }  // namespace
160 
ForceRenderRed(const spv_target_env & target_env,spv_validator_options validator_options,const std::vector<uint32_t> & binary_in,const spvtools::fuzz::protobufs::FactSequence & initial_facts,const MessageConsumer & message_consumer,std::vector<uint32_t> * binary_out)161 bool ForceRenderRed(
162     const spv_target_env& target_env, spv_validator_options validator_options,
163     const std::vector<uint32_t>& binary_in,
164     const spvtools::fuzz::protobufs::FactSequence& initial_facts,
165     const MessageConsumer& message_consumer,
166     std::vector<uint32_t>* binary_out) {
167   spvtools::SpirvTools tools(target_env);
168   if (!tools.IsValid()) {
169     message_consumer(SPV_MSG_ERROR, nullptr, {},
170                      "Failed to create SPIRV-Tools interface; stopping.");
171     return false;
172   }
173 
174   // Initial binary should be valid.
175   if (!tools.Validate(&binary_in[0], binary_in.size(), validator_options)) {
176     message_consumer(SPV_MSG_ERROR, nullptr, {},
177                      "Initial binary is invalid; stopping.");
178     return false;
179   }
180 
181   // Build the module from the input binary.
182   std::unique_ptr<opt::IRContext> ir_context = BuildModule(
183       target_env, message_consumer, binary_in.data(), binary_in.size());
184   assert(ir_context);
185 
186   // Set up a fact manager with any given initial facts.
187   TransformationContext transformation_context(
188       MakeUnique<FactManager>(ir_context.get()), validator_options);
189   for (auto& fact : initial_facts.fact()) {
190     transformation_context.GetFactManager()->MaybeAddFact(fact);
191   }
192 
193   auto entry_point_function =
194       FindFragmentShaderEntryPoint(ir_context.get(), message_consumer);
195   auto output_variable =
196       FindVec4OutputVariable(ir_context.get(), message_consumer);
197   if (entry_point_function == nullptr || output_variable == nullptr) {
198     return false;
199   }
200 
201   opt::analysis::Float temp_float_type(32);
202   opt::analysis::Float* float_type = ir_context->get_type_mgr()
203                                          ->GetRegisteredType(&temp_float_type)
204                                          ->AsFloat();
205   std::pair<uint32_t, uint32_t> zero_one_float_ids =
206       FindOrCreateFloatZeroAndOne(ir_context.get(), float_type);
207 
208   // Make the new exit block
209   auto new_exit_block_id = ir_context->TakeNextId();
210   {
211     auto label = MakeUnique<opt::Instruction>(
212         ir_context.get(), spv::Op::OpLabel, 0, new_exit_block_id,
213         opt::Instruction::OperandList());
214     auto new_exit_block = MakeUnique<opt::BasicBlock>(std::move(label));
215     new_exit_block->AddInstruction(
216         MakeUnique<opt::Instruction>(ir_context.get(), spv::Op::OpReturn, 0, 0,
217                                      opt::Instruction::OperandList()));
218     entry_point_function->AddBasicBlock(std::move(new_exit_block));
219   }
220 
221   // Make the new entry block
222   {
223     auto label = MakeUnique<opt::Instruction>(
224         ir_context.get(), spv::Op::OpLabel, 0, ir_context->TakeNextId(),
225         opt::Instruction::OperandList());
226     auto new_entry_block = MakeUnique<opt::BasicBlock>(std::move(label));
227 
228     // Make an instruction to construct vec4(1.0, 0.0, 0.0, 1.0), representing
229     // the colour red.
230     opt::Operand zero_float = {SPV_OPERAND_TYPE_ID, {zero_one_float_ids.first}};
231     opt::Operand one_float = {SPV_OPERAND_TYPE_ID, {zero_one_float_ids.second}};
232     opt::Instruction::OperandList op_composite_construct_operands = {
233         one_float, zero_float, zero_float, one_float};
234     auto temp_vec4 = opt::analysis::Vector(float_type, 4);
235     auto vec4_id = ir_context->get_type_mgr()->GetId(&temp_vec4);
236     auto red = MakeUnique<opt::Instruction>(
237         ir_context.get(), spv::Op::OpCompositeConstruct, vec4_id,
238         ir_context->TakeNextId(), op_composite_construct_operands);
239     auto red_id = red->result_id();
240     new_entry_block->AddInstruction(std::move(red));
241 
242     // Make an instruction to store red into the output color.
243     opt::Operand variable_to_store_into = {SPV_OPERAND_TYPE_ID,
244                                            {output_variable->result_id()}};
245     opt::Operand value_to_be_stored = {SPV_OPERAND_TYPE_ID, {red_id}};
246     opt::Instruction::OperandList op_store_operands = {variable_to_store_into,
247                                                        value_to_be_stored};
248     new_entry_block->AddInstruction(MakeUnique<opt::Instruction>(
249         ir_context.get(), spv::Op::OpStore, 0, 0, op_store_operands));
250 
251     // We are going to attempt to construct 'false' as an expression of the form
252     // 'literal1 > literal2'. If we succeed, we will later replace each literal
253     // with a uniform of the same value - we can only do that replacement once
254     // we have added the entry block to the module.
255     std::unique_ptr<TransformationReplaceConstantWithUniform>
256         first_greater_then_operand_replacement = nullptr;
257     std::unique_ptr<TransformationReplaceConstantWithUniform>
258         second_greater_then_operand_replacement = nullptr;
259     uint32_t id_guaranteed_to_be_false = 0;
260 
261     opt::analysis::Bool temp_bool_type;
262     opt::analysis::Bool* registered_bool_type =
263         ir_context->get_type_mgr()
264             ->GetRegisteredType(&temp_bool_type)
265             ->AsBool();
266 
267     auto float_type_id = ir_context->get_type_mgr()->GetId(float_type);
268     auto types_for_which_uniforms_are_known =
269         transformation_context.GetFactManager()
270             ->GetTypesForWhichUniformValuesAreKnown();
271 
272     // Check whether we have any float uniforms.
273     if (std::find(types_for_which_uniforms_are_known.begin(),
274                   types_for_which_uniforms_are_known.end(),
275                   float_type_id) != types_for_which_uniforms_are_known.end()) {
276       // We have at least one float uniform; let's see whether we have at least
277       // two.
278       auto available_constants =
279           transformation_context.GetFactManager()
280               ->GetConstantsAvailableFromUniformsForType(float_type_id);
281       if (available_constants.size() > 1) {
282         // Grab the float constants associated with the first two known float
283         // uniforms.
284         auto first_constant =
285             ir_context->get_constant_mgr()
286                 ->GetConstantFromInst(ir_context->get_def_use_mgr()->GetDef(
287                     available_constants[0]))
288                 ->AsFloatConstant();
289         auto second_constant =
290             ir_context->get_constant_mgr()
291                 ->GetConstantFromInst(ir_context->get_def_use_mgr()->GetDef(
292                     available_constants[1]))
293                 ->AsFloatConstant();
294 
295         // Now work out which of the two constants is larger than the other.
296         uint32_t larger_constant_index = 0;
297         uint32_t smaller_constant_index = 0;
298         if (first_constant->GetFloat() > second_constant->GetFloat()) {
299           larger_constant_index = 0;
300           smaller_constant_index = 1;
301         } else if (first_constant->GetFloat() < second_constant->GetFloat()) {
302           larger_constant_index = 1;
303           smaller_constant_index = 0;
304         }
305 
306         // Only proceed with these constants if they have turned out to be
307         // distinct.
308         if (larger_constant_index != smaller_constant_index) {
309           // We are in a position to create 'false' as 'literal1 > literal2', so
310           // reserve an id for this computation; this id will end up being
311           // guaranteed to be 'false'.
312           id_guaranteed_to_be_false = ir_context->TakeNextId();
313 
314           auto smaller_constant = available_constants[smaller_constant_index];
315           auto larger_constant = available_constants[larger_constant_index];
316 
317           opt::Instruction::OperandList greater_than_operands = {
318               {SPV_OPERAND_TYPE_ID, {smaller_constant}},
319               {SPV_OPERAND_TYPE_ID, {larger_constant}}};
320           new_entry_block->AddInstruction(MakeUnique<opt::Instruction>(
321               ir_context.get(), spv::Op::OpFOrdGreaterThan,
322               ir_context->get_type_mgr()->GetId(registered_bool_type),
323               id_guaranteed_to_be_false, greater_than_operands));
324 
325           first_greater_then_operand_replacement =
326               MakeConstantUniformReplacement(
327                   ir_context.get(), *transformation_context.GetFactManager(),
328                   smaller_constant, id_guaranteed_to_be_false, 0);
329           second_greater_then_operand_replacement =
330               MakeConstantUniformReplacement(
331                   ir_context.get(), *transformation_context.GetFactManager(),
332                   larger_constant, id_guaranteed_to_be_false, 1);
333         }
334       }
335     }
336 
337     if (id_guaranteed_to_be_false == 0) {
338       auto constant_false = ir_context->get_constant_mgr()->RegisterConstant(
339           MakeUnique<opt::analysis::BoolConstant>(registered_bool_type, false));
340       id_guaranteed_to_be_false = ir_context->get_constant_mgr()
341                                       ->GetDefiningInstruction(constant_false)
342                                       ->result_id();
343     }
344 
345     opt::Operand false_condition = {SPV_OPERAND_TYPE_ID,
346                                     {id_guaranteed_to_be_false}};
347     opt::Operand then_block = {SPV_OPERAND_TYPE_ID,
348                                {entry_point_function->entry()->id()}};
349     opt::Operand else_block = {SPV_OPERAND_TYPE_ID, {new_exit_block_id}};
350     opt::Instruction::OperandList op_branch_conditional_operands = {
351         false_condition, then_block, else_block};
352     new_entry_block->AddInstruction(MakeUnique<opt::Instruction>(
353         ir_context.get(), spv::Op::OpBranchConditional, 0, 0,
354         op_branch_conditional_operands));
355 
356     entry_point_function->InsertBasicBlockBefore(
357         std::move(new_entry_block), entry_point_function->entry().get());
358 
359     for (auto& replacement : {first_greater_then_operand_replacement.get(),
360                               second_greater_then_operand_replacement.get()}) {
361       if (replacement) {
362         assert(replacement->IsApplicable(ir_context.get(),
363                                          transformation_context));
364         replacement->Apply(ir_context.get(), &transformation_context);
365       }
366     }
367   }
368 
369   // Write out the module as a binary.
370   ir_context->module()->ToBinary(binary_out, false);
371   return true;
372 }
373 
374 }  // namespace fuzz
375 }  // namespace spvtools
376