1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/fuzz/force_render_red.h"
16
17 #include "source/fuzz/fact_manager/fact_manager.h"
18 #include "source/fuzz/instruction_descriptor.h"
19 #include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
20 #include "source/fuzz/transformation_context.h"
21 #include "source/fuzz/transformation_replace_constant_with_uniform.h"
22 #include "source/opt/build_module.h"
23 #include "source/opt/ir_context.h"
24 #include "source/opt/types.h"
25 #include "source/util/make_unique.h"
26
27 namespace spvtools {
28 namespace fuzz {
29
30 namespace {
31
32 // Helper method to find the fragment shader entry point, complaining if there
33 // is no shader or if there is no fragment entry point.
FindFragmentShaderEntryPoint(opt::IRContext * ir_context,MessageConsumer message_consumer)34 opt::Function* FindFragmentShaderEntryPoint(opt::IRContext* ir_context,
35 MessageConsumer message_consumer) {
36 // Check that this is a fragment shader
37 bool found_capability_shader = false;
38 for (auto& capability : ir_context->capabilities()) {
39 assert(capability.opcode() == SpvOpCapability);
40 if (capability.GetSingleWordInOperand(0) == SpvCapabilityShader) {
41 found_capability_shader = true;
42 break;
43 }
44 }
45 if (!found_capability_shader) {
46 message_consumer(
47 SPV_MSG_ERROR, nullptr, {},
48 "Forcing of red rendering requires the Shader capability.");
49 return nullptr;
50 }
51
52 opt::Instruction* fragment_entry_point = nullptr;
53 for (auto& entry_point : ir_context->module()->entry_points()) {
54 if (entry_point.GetSingleWordInOperand(0) == SpvExecutionModelFragment) {
55 fragment_entry_point = &entry_point;
56 break;
57 }
58 }
59 if (fragment_entry_point == nullptr) {
60 message_consumer(SPV_MSG_ERROR, nullptr, {},
61 "Forcing of red rendering requires an entry point with "
62 "the Fragment execution model.");
63 return nullptr;
64 }
65
66 for (auto& function : *ir_context->module()) {
67 if (function.result_id() ==
68 fragment_entry_point->GetSingleWordInOperand(1)) {
69 return &function;
70 }
71 }
72 assert(
73 false &&
74 "A valid module must have a function associate with each entry point.");
75 return nullptr;
76 }
77
78 // Helper method to check that there is a single vec4 output variable and get a
79 // pointer to it.
FindVec4OutputVariable(opt::IRContext * ir_context,MessageConsumer message_consumer)80 opt::Instruction* FindVec4OutputVariable(opt::IRContext* ir_context,
81 MessageConsumer message_consumer) {
82 opt::Instruction* output_variable = nullptr;
83 for (auto& inst : ir_context->types_values()) {
84 if (inst.opcode() == SpvOpVariable &&
85 inst.GetSingleWordInOperand(0) == SpvStorageClassOutput) {
86 if (output_variable != nullptr) {
87 message_consumer(SPV_MSG_ERROR, nullptr, {},
88 "Only one output variable can be handled at present; "
89 "found multiple.");
90 return nullptr;
91 }
92 output_variable = &inst;
93 // Do not break, as we want to check for multiple output variables.
94 }
95 }
96 if (output_variable == nullptr) {
97 message_consumer(SPV_MSG_ERROR, nullptr, {},
98 "No output variable to which to write red was found.");
99 return nullptr;
100 }
101
102 auto output_variable_base_type = ir_context->get_type_mgr()
103 ->GetType(output_variable->type_id())
104 ->AsPointer()
105 ->pointee_type()
106 ->AsVector();
107 if (!output_variable_base_type ||
108 output_variable_base_type->element_count() != 4 ||
109 !output_variable_base_type->element_type()->AsFloat()) {
110 message_consumer(SPV_MSG_ERROR, nullptr, {},
111 "The output variable must have type vec4.");
112 return nullptr;
113 }
114
115 return output_variable;
116 }
117
118 // Helper to get the ids of float constants 0.0 and 1.0, creating them if
119 // necessary.
FindOrCreateFloatZeroAndOne(opt::IRContext * ir_context,opt::analysis::Float * float_type)120 std::pair<uint32_t, uint32_t> FindOrCreateFloatZeroAndOne(
121 opt::IRContext* ir_context, opt::analysis::Float* float_type) {
122 float one = 1.0;
123 uint32_t one_as_uint;
124 memcpy(&one_as_uint, &one, sizeof(float));
125 std::vector<uint32_t> zero_bytes = {0};
126 std::vector<uint32_t> one_bytes = {one_as_uint};
127 auto constant_zero = ir_context->get_constant_mgr()->RegisterConstant(
128 MakeUnique<opt::analysis::FloatConstant>(float_type, zero_bytes));
129 auto constant_one = ir_context->get_constant_mgr()->RegisterConstant(
130 MakeUnique<opt::analysis::FloatConstant>(float_type, one_bytes));
131 auto constant_zero_id = ir_context->get_constant_mgr()
132 ->GetDefiningInstruction(constant_zero)
133 ->result_id();
134 auto constant_one_id = ir_context->get_constant_mgr()
135 ->GetDefiningInstruction(constant_one)
136 ->result_id();
137 return std::pair<uint32_t, uint32_t>(constant_zero_id, constant_one_id);
138 }
139
140 std::unique_ptr<TransformationReplaceConstantWithUniform>
MakeConstantUniformReplacement(opt::IRContext * ir_context,const FactManager & fact_manager,uint32_t constant_id,uint32_t greater_than_instruction,uint32_t in_operand_index)141 MakeConstantUniformReplacement(opt::IRContext* ir_context,
142 const FactManager& fact_manager,
143 uint32_t constant_id,
144 uint32_t greater_than_instruction,
145 uint32_t in_operand_index) {
146 return MakeUnique<TransformationReplaceConstantWithUniform>(
147 MakeIdUseDescriptor(constant_id,
148 MakeInstructionDescriptor(greater_than_instruction,
149 SpvOpFOrdGreaterThan, 0),
150 in_operand_index),
151 fact_manager.GetUniformDescriptorsForConstant(constant_id)[0],
152 ir_context->TakeNextId(), ir_context->TakeNextId());
153 }
154
155 } // namespace
156
ForceRenderRed(const spv_target_env & target_env,spv_validator_options validator_options,const std::vector<uint32_t> & binary_in,const spvtools::fuzz::protobufs::FactSequence & initial_facts,const MessageConsumer & message_consumer,std::vector<uint32_t> * binary_out)157 bool ForceRenderRed(
158 const spv_target_env& target_env, spv_validator_options validator_options,
159 const std::vector<uint32_t>& binary_in,
160 const spvtools::fuzz::protobufs::FactSequence& initial_facts,
161 const MessageConsumer& message_consumer,
162 std::vector<uint32_t>* binary_out) {
163 spvtools::SpirvTools tools(target_env);
164 if (!tools.IsValid()) {
165 message_consumer(SPV_MSG_ERROR, nullptr, {},
166 "Failed to create SPIRV-Tools interface; stopping.");
167 return false;
168 }
169
170 // Initial binary should be valid.
171 if (!tools.Validate(&binary_in[0], binary_in.size(), validator_options)) {
172 message_consumer(SPV_MSG_ERROR, nullptr, {},
173 "Initial binary is invalid; stopping.");
174 return false;
175 }
176
177 // Build the module from the input binary.
178 std::unique_ptr<opt::IRContext> ir_context = BuildModule(
179 target_env, message_consumer, binary_in.data(), binary_in.size());
180 assert(ir_context);
181
182 // Set up a fact manager with any given initial facts.
183 TransformationContext transformation_context(
184 MakeUnique<FactManager>(ir_context.get()), validator_options);
185 for (auto& fact : initial_facts.fact()) {
186 transformation_context.GetFactManager()->MaybeAddFact(fact);
187 }
188
189 auto entry_point_function =
190 FindFragmentShaderEntryPoint(ir_context.get(), message_consumer);
191 auto output_variable =
192 FindVec4OutputVariable(ir_context.get(), message_consumer);
193 if (entry_point_function == nullptr || output_variable == nullptr) {
194 return false;
195 }
196
197 opt::analysis::Float temp_float_type(32);
198 opt::analysis::Float* float_type = ir_context->get_type_mgr()
199 ->GetRegisteredType(&temp_float_type)
200 ->AsFloat();
201 std::pair<uint32_t, uint32_t> zero_one_float_ids =
202 FindOrCreateFloatZeroAndOne(ir_context.get(), float_type);
203
204 // Make the new exit block
205 auto new_exit_block_id = ir_context->TakeNextId();
206 {
207 auto label = MakeUnique<opt::Instruction>(ir_context.get(), SpvOpLabel, 0,
208 new_exit_block_id,
209 opt::Instruction::OperandList());
210 auto new_exit_block = MakeUnique<opt::BasicBlock>(std::move(label));
211 new_exit_block->AddInstruction(MakeUnique<opt::Instruction>(
212 ir_context.get(), SpvOpReturn, 0, 0, opt::Instruction::OperandList()));
213 entry_point_function->AddBasicBlock(std::move(new_exit_block));
214 }
215
216 // Make the new entry block
217 {
218 auto label = MakeUnique<opt::Instruction>(ir_context.get(), SpvOpLabel, 0,
219 ir_context->TakeNextId(),
220 opt::Instruction::OperandList());
221 auto new_entry_block = MakeUnique<opt::BasicBlock>(std::move(label));
222
223 // Make an instruction to construct vec4(1.0, 0.0, 0.0, 1.0), representing
224 // the colour red.
225 opt::Operand zero_float = {SPV_OPERAND_TYPE_ID, {zero_one_float_ids.first}};
226 opt::Operand one_float = {SPV_OPERAND_TYPE_ID, {zero_one_float_ids.second}};
227 opt::Instruction::OperandList op_composite_construct_operands = {
228 one_float, zero_float, zero_float, one_float};
229 auto temp_vec4 = opt::analysis::Vector(float_type, 4);
230 auto vec4_id = ir_context->get_type_mgr()->GetId(&temp_vec4);
231 auto red = MakeUnique<opt::Instruction>(
232 ir_context.get(), SpvOpCompositeConstruct, vec4_id,
233 ir_context->TakeNextId(), op_composite_construct_operands);
234 auto red_id = red->result_id();
235 new_entry_block->AddInstruction(std::move(red));
236
237 // Make an instruction to store red into the output color.
238 opt::Operand variable_to_store_into = {SPV_OPERAND_TYPE_ID,
239 {output_variable->result_id()}};
240 opt::Operand value_to_be_stored = {SPV_OPERAND_TYPE_ID, {red_id}};
241 opt::Instruction::OperandList op_store_operands = {variable_to_store_into,
242 value_to_be_stored};
243 new_entry_block->AddInstruction(MakeUnique<opt::Instruction>(
244 ir_context.get(), SpvOpStore, 0, 0, op_store_operands));
245
246 // We are going to attempt to construct 'false' as an expression of the form
247 // 'literal1 > literal2'. If we succeed, we will later replace each literal
248 // with a uniform of the same value - we can only do that replacement once
249 // we have added the entry block to the module.
250 std::unique_ptr<TransformationReplaceConstantWithUniform>
251 first_greater_then_operand_replacement = nullptr;
252 std::unique_ptr<TransformationReplaceConstantWithUniform>
253 second_greater_then_operand_replacement = nullptr;
254 uint32_t id_guaranteed_to_be_false = 0;
255
256 opt::analysis::Bool temp_bool_type;
257 opt::analysis::Bool* registered_bool_type =
258 ir_context->get_type_mgr()
259 ->GetRegisteredType(&temp_bool_type)
260 ->AsBool();
261
262 auto float_type_id = ir_context->get_type_mgr()->GetId(float_type);
263 auto types_for_which_uniforms_are_known =
264 transformation_context.GetFactManager()
265 ->GetTypesForWhichUniformValuesAreKnown();
266
267 // Check whether we have any float uniforms.
268 if (std::find(types_for_which_uniforms_are_known.begin(),
269 types_for_which_uniforms_are_known.end(),
270 float_type_id) != types_for_which_uniforms_are_known.end()) {
271 // We have at least one float uniform; let's see whether we have at least
272 // two.
273 auto available_constants =
274 transformation_context.GetFactManager()
275 ->GetConstantsAvailableFromUniformsForType(float_type_id);
276 if (available_constants.size() > 1) {
277 // Grab the float constants associated with the first two known float
278 // uniforms.
279 auto first_constant =
280 ir_context->get_constant_mgr()
281 ->GetConstantFromInst(ir_context->get_def_use_mgr()->GetDef(
282 available_constants[0]))
283 ->AsFloatConstant();
284 auto second_constant =
285 ir_context->get_constant_mgr()
286 ->GetConstantFromInst(ir_context->get_def_use_mgr()->GetDef(
287 available_constants[1]))
288 ->AsFloatConstant();
289
290 // Now work out which of the two constants is larger than the other.
291 uint32_t larger_constant_index = 0;
292 uint32_t smaller_constant_index = 0;
293 if (first_constant->GetFloat() > second_constant->GetFloat()) {
294 larger_constant_index = 0;
295 smaller_constant_index = 1;
296 } else if (first_constant->GetFloat() < second_constant->GetFloat()) {
297 larger_constant_index = 1;
298 smaller_constant_index = 0;
299 }
300
301 // Only proceed with these constants if they have turned out to be
302 // distinct.
303 if (larger_constant_index != smaller_constant_index) {
304 // We are in a position to create 'false' as 'literal1 > literal2', so
305 // reserve an id for this computation; this id will end up being
306 // guaranteed to be 'false'.
307 id_guaranteed_to_be_false = ir_context->TakeNextId();
308
309 auto smaller_constant = available_constants[smaller_constant_index];
310 auto larger_constant = available_constants[larger_constant_index];
311
312 opt::Instruction::OperandList greater_than_operands = {
313 {SPV_OPERAND_TYPE_ID, {smaller_constant}},
314 {SPV_OPERAND_TYPE_ID, {larger_constant}}};
315 new_entry_block->AddInstruction(MakeUnique<opt::Instruction>(
316 ir_context.get(), SpvOpFOrdGreaterThan,
317 ir_context->get_type_mgr()->GetId(registered_bool_type),
318 id_guaranteed_to_be_false, greater_than_operands));
319
320 first_greater_then_operand_replacement =
321 MakeConstantUniformReplacement(
322 ir_context.get(), *transformation_context.GetFactManager(),
323 smaller_constant, id_guaranteed_to_be_false, 0);
324 second_greater_then_operand_replacement =
325 MakeConstantUniformReplacement(
326 ir_context.get(), *transformation_context.GetFactManager(),
327 larger_constant, id_guaranteed_to_be_false, 1);
328 }
329 }
330 }
331
332 if (id_guaranteed_to_be_false == 0) {
333 auto constant_false = ir_context->get_constant_mgr()->RegisterConstant(
334 MakeUnique<opt::analysis::BoolConstant>(registered_bool_type, false));
335 id_guaranteed_to_be_false = ir_context->get_constant_mgr()
336 ->GetDefiningInstruction(constant_false)
337 ->result_id();
338 }
339
340 opt::Operand false_condition = {SPV_OPERAND_TYPE_ID,
341 {id_guaranteed_to_be_false}};
342 opt::Operand then_block = {SPV_OPERAND_TYPE_ID,
343 {entry_point_function->entry()->id()}};
344 opt::Operand else_block = {SPV_OPERAND_TYPE_ID, {new_exit_block_id}};
345 opt::Instruction::OperandList op_branch_conditional_operands = {
346 false_condition, then_block, else_block};
347 new_entry_block->AddInstruction(
348 MakeUnique<opt::Instruction>(ir_context.get(), SpvOpBranchConditional,
349 0, 0, op_branch_conditional_operands));
350
351 entry_point_function->InsertBasicBlockBefore(
352 std::move(new_entry_block), entry_point_function->entry().get());
353
354 for (auto& replacement : {first_greater_then_operand_replacement.get(),
355 second_greater_then_operand_replacement.get()}) {
356 if (replacement) {
357 assert(replacement->IsApplicable(ir_context.get(),
358 transformation_context));
359 replacement->Apply(ir_context.get(), &transformation_context);
360 }
361 }
362 }
363
364 // Write out the module as a binary.
365 ir_context->module()->ToBinary(binary_out, false);
366 return true;
367 }
368
369 } // namespace fuzz
370 } // namespace spvtools
371