1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/fuzz/force_render_red.h"
16
17 #include "source/fuzz/fact_manager.h"
18 #include "source/fuzz/instruction_descriptor.h"
19 #include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
20 #include "source/fuzz/transformation_context.h"
21 #include "source/fuzz/transformation_replace_constant_with_uniform.h"
22 #include "source/fuzz/uniform_buffer_element_descriptor.h"
23 #include "source/opt/build_module.h"
24 #include "source/opt/ir_context.h"
25 #include "source/opt/types.h"
26 #include "source/util/make_unique.h"
27 #include "tools/util/cli_consumer.h"
28
29 #include <algorithm>
30 #include <utility>
31
32 namespace spvtools {
33 namespace fuzz {
34
35 namespace {
36
37 // Helper method to find the fragment shader entry point, complaining if there
38 // is no shader or if there is no fragment entry point.
FindFragmentShaderEntryPoint(opt::IRContext * ir_context,MessageConsumer message_consumer)39 opt::Function* FindFragmentShaderEntryPoint(opt::IRContext* ir_context,
40 MessageConsumer message_consumer) {
41 // Check that this is a fragment shader
42 bool found_capability_shader = false;
43 for (auto& capability : ir_context->capabilities()) {
44 assert(capability.opcode() == SpvOpCapability);
45 if (capability.GetSingleWordInOperand(0) == SpvCapabilityShader) {
46 found_capability_shader = true;
47 break;
48 }
49 }
50 if (!found_capability_shader) {
51 message_consumer(
52 SPV_MSG_ERROR, nullptr, {},
53 "Forcing of red rendering requires the Shader capability.");
54 return nullptr;
55 }
56
57 opt::Instruction* fragment_entry_point = nullptr;
58 for (auto& entry_point : ir_context->module()->entry_points()) {
59 if (entry_point.GetSingleWordInOperand(0) == SpvExecutionModelFragment) {
60 fragment_entry_point = &entry_point;
61 break;
62 }
63 }
64 if (fragment_entry_point == nullptr) {
65 message_consumer(SPV_MSG_ERROR, nullptr, {},
66 "Forcing of red rendering requires an entry point with "
67 "the Fragment execution model.");
68 return nullptr;
69 }
70
71 for (auto& function : *ir_context->module()) {
72 if (function.result_id() ==
73 fragment_entry_point->GetSingleWordInOperand(1)) {
74 return &function;
75 }
76 }
77 assert(
78 false &&
79 "A valid module must have a function associate with each entry point.");
80 return nullptr;
81 }
82
83 // Helper method to check that there is a single vec4 output variable and get a
84 // pointer to it.
FindVec4OutputVariable(opt::IRContext * ir_context,MessageConsumer message_consumer)85 opt::Instruction* FindVec4OutputVariable(opt::IRContext* ir_context,
86 MessageConsumer message_consumer) {
87 opt::Instruction* output_variable = nullptr;
88 for (auto& inst : ir_context->types_values()) {
89 if (inst.opcode() == SpvOpVariable &&
90 inst.GetSingleWordInOperand(0) == SpvStorageClassOutput) {
91 if (output_variable != nullptr) {
92 message_consumer(SPV_MSG_ERROR, nullptr, {},
93 "Only one output variable can be handled at present; "
94 "found multiple.");
95 return nullptr;
96 }
97 output_variable = &inst;
98 // Do not break, as we want to check for multiple output variables.
99 }
100 }
101 if (output_variable == nullptr) {
102 message_consumer(SPV_MSG_ERROR, nullptr, {},
103 "No output variable to which to write red was found.");
104 return nullptr;
105 }
106
107 auto output_variable_base_type = ir_context->get_type_mgr()
108 ->GetType(output_variable->type_id())
109 ->AsPointer()
110 ->pointee_type()
111 ->AsVector();
112 if (!output_variable_base_type ||
113 output_variable_base_type->element_count() != 4 ||
114 !output_variable_base_type->element_type()->AsFloat()) {
115 message_consumer(SPV_MSG_ERROR, nullptr, {},
116 "The output variable must have type vec4.");
117 return nullptr;
118 }
119
120 return output_variable;
121 }
122
123 // Helper to get the ids of float constants 0.0 and 1.0, creating them if
124 // necessary.
FindOrCreateFloatZeroAndOne(opt::IRContext * ir_context,opt::analysis::Float * float_type)125 std::pair<uint32_t, uint32_t> FindOrCreateFloatZeroAndOne(
126 opt::IRContext* ir_context, opt::analysis::Float* float_type) {
127 float one = 1.0;
128 uint32_t one_as_uint;
129 memcpy(&one_as_uint, &one, sizeof(float));
130 std::vector<uint32_t> zero_bytes = {0};
131 std::vector<uint32_t> one_bytes = {one_as_uint};
132 auto constant_zero = ir_context->get_constant_mgr()->RegisterConstant(
133 MakeUnique<opt::analysis::FloatConstant>(float_type, zero_bytes));
134 auto constant_one = ir_context->get_constant_mgr()->RegisterConstant(
135 MakeUnique<opt::analysis::FloatConstant>(float_type, one_bytes));
136 auto constant_zero_id = ir_context->get_constant_mgr()
137 ->GetDefiningInstruction(constant_zero)
138 ->result_id();
139 auto constant_one_id = ir_context->get_constant_mgr()
140 ->GetDefiningInstruction(constant_one)
141 ->result_id();
142 return std::pair<uint32_t, uint32_t>(constant_zero_id, constant_one_id);
143 }
144
145 std::unique_ptr<TransformationReplaceConstantWithUniform>
MakeConstantUniformReplacement(opt::IRContext * ir_context,const FactManager & fact_manager,uint32_t constant_id,uint32_t greater_than_instruction,uint32_t in_operand_index)146 MakeConstantUniformReplacement(opt::IRContext* ir_context,
147 const FactManager& fact_manager,
148 uint32_t constant_id,
149 uint32_t greater_than_instruction,
150 uint32_t in_operand_index) {
151 return MakeUnique<TransformationReplaceConstantWithUniform>(
152 MakeIdUseDescriptor(constant_id,
153 MakeInstructionDescriptor(greater_than_instruction,
154 SpvOpFOrdGreaterThan, 0),
155 in_operand_index),
156 fact_manager.GetUniformDescriptorsForConstant(ir_context, constant_id)[0],
157 ir_context->TakeNextId(), ir_context->TakeNextId());
158 }
159
160 } // namespace
161
ForceRenderRed(const spv_target_env & target_env,spv_validator_options validator_options,const std::vector<uint32_t> & binary_in,const spvtools::fuzz::protobufs::FactSequence & initial_facts,std::vector<uint32_t> * binary_out)162 bool ForceRenderRed(
163 const spv_target_env& target_env, spv_validator_options validator_options,
164 const std::vector<uint32_t>& binary_in,
165 const spvtools::fuzz::protobufs::FactSequence& initial_facts,
166 std::vector<uint32_t>* binary_out) {
167 auto message_consumer = spvtools::utils::CLIMessageConsumer;
168 spvtools::SpirvTools tools(target_env);
169 if (!tools.IsValid()) {
170 message_consumer(SPV_MSG_ERROR, nullptr, {},
171 "Failed to create SPIRV-Tools interface; stopping.");
172 return false;
173 }
174
175 // Initial binary should be valid.
176 if (!tools.Validate(&binary_in[0], binary_in.size(), validator_options)) {
177 message_consumer(SPV_MSG_ERROR, nullptr, {},
178 "Initial binary is invalid; stopping.");
179 return false;
180 }
181
182 // Build the module from the input binary.
183 std::unique_ptr<opt::IRContext> ir_context = BuildModule(
184 target_env, message_consumer, binary_in.data(), binary_in.size());
185 assert(ir_context);
186
187 // Set up a fact manager with any given initial facts.
188 FactManager fact_manager;
189 for (auto& fact : initial_facts.fact()) {
190 fact_manager.AddFact(fact, ir_context.get());
191 }
192 TransformationContext transformation_context(&fact_manager,
193 validator_options);
194
195 auto entry_point_function =
196 FindFragmentShaderEntryPoint(ir_context.get(), message_consumer);
197 auto output_variable =
198 FindVec4OutputVariable(ir_context.get(), message_consumer);
199 if (entry_point_function == nullptr || output_variable == nullptr) {
200 return false;
201 }
202
203 opt::analysis::Float temp_float_type(32);
204 opt::analysis::Float* float_type = ir_context->get_type_mgr()
205 ->GetRegisteredType(&temp_float_type)
206 ->AsFloat();
207 std::pair<uint32_t, uint32_t> zero_one_float_ids =
208 FindOrCreateFloatZeroAndOne(ir_context.get(), float_type);
209
210 // Make the new exit block
211 auto new_exit_block_id = ir_context->TakeNextId();
212 {
213 auto label = MakeUnique<opt::Instruction>(ir_context.get(), SpvOpLabel, 0,
214 new_exit_block_id,
215 opt::Instruction::OperandList());
216 auto new_exit_block = MakeUnique<opt::BasicBlock>(std::move(label));
217 new_exit_block->AddInstruction(MakeUnique<opt::Instruction>(
218 ir_context.get(), SpvOpReturn, 0, 0, opt::Instruction::OperandList()));
219 entry_point_function->AddBasicBlock(std::move(new_exit_block));
220 }
221
222 // Make the new entry block
223 {
224 auto label = MakeUnique<opt::Instruction>(ir_context.get(), SpvOpLabel, 0,
225 ir_context->TakeNextId(),
226 opt::Instruction::OperandList());
227 auto new_entry_block = MakeUnique<opt::BasicBlock>(std::move(label));
228
229 // Make an instruction to construct vec4(1.0, 0.0, 0.0, 1.0), representing
230 // the colour red.
231 opt::Operand zero_float = {SPV_OPERAND_TYPE_ID, {zero_one_float_ids.first}};
232 opt::Operand one_float = {SPV_OPERAND_TYPE_ID, {zero_one_float_ids.second}};
233 opt::Instruction::OperandList op_composite_construct_operands = {
234 one_float, zero_float, zero_float, one_float};
235 auto temp_vec4 = opt::analysis::Vector(float_type, 4);
236 auto vec4_id = ir_context->get_type_mgr()->GetId(&temp_vec4);
237 auto red = MakeUnique<opt::Instruction>(
238 ir_context.get(), SpvOpCompositeConstruct, vec4_id,
239 ir_context->TakeNextId(), op_composite_construct_operands);
240 auto red_id = red->result_id();
241 new_entry_block->AddInstruction(std::move(red));
242
243 // Make an instruction to store red into the output color.
244 opt::Operand variable_to_store_into = {SPV_OPERAND_TYPE_ID,
245 {output_variable->result_id()}};
246 opt::Operand value_to_be_stored = {SPV_OPERAND_TYPE_ID, {red_id}};
247 opt::Instruction::OperandList op_store_operands = {variable_to_store_into,
248 value_to_be_stored};
249 new_entry_block->AddInstruction(MakeUnique<opt::Instruction>(
250 ir_context.get(), SpvOpStore, 0, 0, op_store_operands));
251
252 // We are going to attempt to construct 'false' as an expression of the form
253 // 'literal1 > literal2'. If we succeed, we will later replace each literal
254 // with a uniform of the same value - we can only do that replacement once
255 // we have added the entry block to the module.
256 std::unique_ptr<TransformationReplaceConstantWithUniform>
257 first_greater_then_operand_replacement = nullptr;
258 std::unique_ptr<TransformationReplaceConstantWithUniform>
259 second_greater_then_operand_replacement = nullptr;
260 uint32_t id_guaranteed_to_be_false = 0;
261
262 opt::analysis::Bool temp_bool_type;
263 opt::analysis::Bool* registered_bool_type =
264 ir_context->get_type_mgr()
265 ->GetRegisteredType(&temp_bool_type)
266 ->AsBool();
267
268 auto float_type_id = ir_context->get_type_mgr()->GetId(float_type);
269 auto types_for_which_uniforms_are_known =
270 fact_manager.GetTypesForWhichUniformValuesAreKnown();
271
272 // Check whether we have any float uniforms.
273 if (std::find(types_for_which_uniforms_are_known.begin(),
274 types_for_which_uniforms_are_known.end(),
275 float_type_id) != types_for_which_uniforms_are_known.end()) {
276 // We have at least one float uniform; let's see whether we have at least
277 // two.
278 auto available_constants =
279 fact_manager.GetConstantsAvailableFromUniformsForType(
280 ir_context.get(), float_type_id);
281 if (available_constants.size() > 1) {
282 // Grab the float constants associated with the first two known float
283 // uniforms.
284 auto first_constant =
285 ir_context->get_constant_mgr()
286 ->GetConstantFromInst(ir_context->get_def_use_mgr()->GetDef(
287 available_constants[0]))
288 ->AsFloatConstant();
289 auto second_constant =
290 ir_context->get_constant_mgr()
291 ->GetConstantFromInst(ir_context->get_def_use_mgr()->GetDef(
292 available_constants[1]))
293 ->AsFloatConstant();
294
295 // Now work out which of the two constants is larger than the other.
296 uint32_t larger_constant_index = 0;
297 uint32_t smaller_constant_index = 0;
298 if (first_constant->GetFloat() > second_constant->GetFloat()) {
299 larger_constant_index = 0;
300 smaller_constant_index = 1;
301 } else if (first_constant->GetFloat() < second_constant->GetFloat()) {
302 larger_constant_index = 1;
303 smaller_constant_index = 0;
304 }
305
306 // Only proceed with these constants if they have turned out to be
307 // distinct.
308 if (larger_constant_index != smaller_constant_index) {
309 // We are in a position to create 'false' as 'literal1 > literal2', so
310 // reserve an id for this computation; this id will end up being
311 // guaranteed to be 'false'.
312 id_guaranteed_to_be_false = ir_context->TakeNextId();
313
314 auto smaller_constant = available_constants[smaller_constant_index];
315 auto larger_constant = available_constants[larger_constant_index];
316
317 opt::Instruction::OperandList greater_than_operands = {
318 {SPV_OPERAND_TYPE_ID, {smaller_constant}},
319 {SPV_OPERAND_TYPE_ID, {larger_constant}}};
320 new_entry_block->AddInstruction(MakeUnique<opt::Instruction>(
321 ir_context.get(), SpvOpFOrdGreaterThan,
322 ir_context->get_type_mgr()->GetId(registered_bool_type),
323 id_guaranteed_to_be_false, greater_than_operands));
324
325 first_greater_then_operand_replacement =
326 MakeConstantUniformReplacement(ir_context.get(), fact_manager,
327 smaller_constant,
328 id_guaranteed_to_be_false, 0);
329 second_greater_then_operand_replacement =
330 MakeConstantUniformReplacement(ir_context.get(), fact_manager,
331 larger_constant,
332 id_guaranteed_to_be_false, 1);
333 }
334 }
335 }
336
337 if (id_guaranteed_to_be_false == 0) {
338 auto constant_false = ir_context->get_constant_mgr()->RegisterConstant(
339 MakeUnique<opt::analysis::BoolConstant>(registered_bool_type, false));
340 id_guaranteed_to_be_false = ir_context->get_constant_mgr()
341 ->GetDefiningInstruction(constant_false)
342 ->result_id();
343 }
344
345 opt::Operand false_condition = {SPV_OPERAND_TYPE_ID,
346 {id_guaranteed_to_be_false}};
347 opt::Operand then_block = {SPV_OPERAND_TYPE_ID,
348 {entry_point_function->entry()->id()}};
349 opt::Operand else_block = {SPV_OPERAND_TYPE_ID, {new_exit_block_id}};
350 opt::Instruction::OperandList op_branch_conditional_operands = {
351 false_condition, then_block, else_block};
352 new_entry_block->AddInstruction(
353 MakeUnique<opt::Instruction>(ir_context.get(), SpvOpBranchConditional,
354 0, 0, op_branch_conditional_operands));
355
356 entry_point_function->InsertBasicBlockBefore(
357 std::move(new_entry_block), entry_point_function->entry().get());
358
359 for (auto& replacement : {first_greater_then_operand_replacement.get(),
360 second_greater_then_operand_replacement.get()}) {
361 if (replacement) {
362 assert(replacement->IsApplicable(ir_context.get(),
363 transformation_context));
364 replacement->Apply(ir_context.get(), &transformation_context);
365 }
366 }
367 }
368
369 // Write out the module as a binary.
370 ir_context->module()->ToBinary(binary_out, false);
371 return true;
372 }
373
374 } // namespace fuzz
375 } // namespace spvtools
376