1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/fuzz/transformation_replace_constant_with_uniform.h"
16
17 #include "source/fuzz/fuzzer_util.h"
18 #include "source/fuzz/uniform_buffer_element_descriptor.h"
19
20 namespace spvtools {
21 namespace fuzz {
22
23 TransformationReplaceConstantWithUniform::
TransformationReplaceConstantWithUniform(protobufs::TransformationReplaceConstantWithUniform message)24 TransformationReplaceConstantWithUniform(
25 protobufs::TransformationReplaceConstantWithUniform message)
26 : message_(std::move(message)) {}
27
28 TransformationReplaceConstantWithUniform::
TransformationReplaceConstantWithUniform(protobufs::IdUseDescriptor id_use,protobufs::UniformBufferElementDescriptor uniform_descriptor,uint32_t fresh_id_for_access_chain,uint32_t fresh_id_for_load)29 TransformationReplaceConstantWithUniform(
30 protobufs::IdUseDescriptor id_use,
31 protobufs::UniformBufferElementDescriptor uniform_descriptor,
32 uint32_t fresh_id_for_access_chain, uint32_t fresh_id_for_load) {
33 *message_.mutable_id_use_descriptor() = std::move(id_use);
34 *message_.mutable_uniform_descriptor() = std::move(uniform_descriptor);
35 message_.set_fresh_id_for_access_chain(fresh_id_for_access_chain);
36 message_.set_fresh_id_for_load(fresh_id_for_load);
37 }
38
39 std::unique_ptr<opt::Instruction>
MakeAccessChainInstruction(spvtools::opt::IRContext * ir_context,uint32_t constant_type_id) const40 TransformationReplaceConstantWithUniform::MakeAccessChainInstruction(
41 spvtools::opt::IRContext* ir_context, uint32_t constant_type_id) const {
42 // The input operands for the access chain.
43 opt::Instruction::OperandList operands_for_access_chain;
44
45 opt::Instruction* uniform_variable =
46 FindUniformVariable(message_.uniform_descriptor(), ir_context, false);
47
48 // The first input operand is the id of the uniform variable.
49 operands_for_access_chain.push_back(
50 {SPV_OPERAND_TYPE_ID, {uniform_variable->result_id()}});
51
52 // The other input operands are the ids of the constants used to index into
53 // the uniform. The uniform buffer descriptor specifies a series of literals;
54 // for each we find the id of the instruction that defines it, and add these
55 // instruction ids as operands.
56 opt::analysis::Integer int_type(32, true);
57 auto registered_int_type =
58 ir_context->get_type_mgr()->GetRegisteredType(&int_type)->AsInteger();
59 auto int_type_id = ir_context->get_type_mgr()->GetId(&int_type);
60 for (auto index : message_.uniform_descriptor().index()) {
61 opt::analysis::IntConstant int_constant(registered_int_type, {index});
62 auto constant_id = ir_context->get_constant_mgr()->FindDeclaredConstant(
63 &int_constant, int_type_id);
64 operands_for_access_chain.push_back({SPV_OPERAND_TYPE_ID, {constant_id}});
65 }
66
67 // The type id for the access chain is a uniform pointer with base type
68 // matching the given constant id type.
69 auto type_and_pointer_type =
70 ir_context->get_type_mgr()->GetTypeAndPointerType(
71 constant_type_id, spv::StorageClass::Uniform);
72 assert(type_and_pointer_type.first != nullptr);
73 assert(type_and_pointer_type.second != nullptr);
74 auto pointer_to_uniform_constant_type_id =
75 ir_context->get_type_mgr()->GetId(type_and_pointer_type.second.get());
76
77 return MakeUnique<opt::Instruction>(
78 ir_context, spv::Op::OpAccessChain, pointer_to_uniform_constant_type_id,
79 message_.fresh_id_for_access_chain(), operands_for_access_chain);
80 }
81
82 std::unique_ptr<opt::Instruction>
MakeLoadInstruction(spvtools::opt::IRContext * ir_context,uint32_t constant_type_id) const83 TransformationReplaceConstantWithUniform::MakeLoadInstruction(
84 spvtools::opt::IRContext* ir_context, uint32_t constant_type_id) const {
85 opt::Instruction::OperandList operands_for_load = {
86 {SPV_OPERAND_TYPE_ID, {message_.fresh_id_for_access_chain()}}};
87 return MakeUnique<opt::Instruction>(
88 ir_context, spv::Op::OpLoad, constant_type_id,
89 message_.fresh_id_for_load(), operands_for_load);
90 }
91
92 opt::Instruction*
GetInsertBeforeInstruction(opt::IRContext * ir_context) const93 TransformationReplaceConstantWithUniform::GetInsertBeforeInstruction(
94 opt::IRContext* ir_context) const {
95 auto* result =
96 FindInstructionContainingUse(message_.id_use_descriptor(), ir_context);
97 if (!result) {
98 return nullptr;
99 }
100
101 // The use might be in an OpPhi instruction.
102 if (result->opcode() == spv::Op::OpPhi) {
103 // OpPhi instructions must be the first instructions in a block. Thus, we
104 // can't insert above the OpPhi instruction. Given the predecessor block
105 // that corresponds to the id use, get the last instruction in that block
106 // above which we can insert OpAccessChain and OpLoad.
107 return fuzzerutil::GetLastInsertBeforeInstruction(
108 ir_context,
109 result->GetSingleWordInOperand(
110 message_.id_use_descriptor().in_operand_index() + 1),
111 spv::Op::OpLoad);
112 }
113
114 // The only operand that we could've replaced in the OpBranchConditional is
115 // the condition id. But that operand has a boolean type and uniform variables
116 // can't store booleans (see the spec on OpTypeBool). Thus, |result| can't be
117 // an OpBranchConditional.
118 assert(result->opcode() != spv::Op::OpBranchConditional &&
119 "OpBranchConditional has no operands to replace");
120
121 assert(
122 fuzzerutil::CanInsertOpcodeBeforeInstruction(spv::Op::OpLoad, result) &&
123 "We should be able to insert OpLoad and OpAccessChain at this point");
124 return result;
125 }
126
IsApplicable(opt::IRContext * ir_context,const TransformationContext & transformation_context) const127 bool TransformationReplaceConstantWithUniform::IsApplicable(
128 opt::IRContext* ir_context,
129 const TransformationContext& transformation_context) const {
130 // The following is really an invariant of the transformation rather than
131 // merely a requirement of the precondition. We check it here since we cannot
132 // check it in the message_ constructor.
133 assert(message_.fresh_id_for_access_chain() != message_.fresh_id_for_load() &&
134 "Fresh ids for access chain and load result cannot be the same.");
135
136 // The ids for the access chain and load instructions must both be fresh.
137 if (!fuzzerutil::IsFreshId(ir_context,
138 message_.fresh_id_for_access_chain())) {
139 return false;
140 }
141 if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id_for_load())) {
142 return false;
143 }
144
145 // The id specified in the id use descriptor must be that of a declared scalar
146 // constant.
147 auto declared_constant = ir_context->get_constant_mgr()->FindDeclaredConstant(
148 message_.id_use_descriptor().id_of_interest());
149 if (!declared_constant) {
150 return false;
151 }
152 if (!declared_constant->AsScalarConstant()) {
153 return false;
154 }
155
156 // The fact manager needs to believe that the uniform data element described
157 // by the uniform buffer element descriptor will hold a scalar value.
158 auto constant_id_associated_with_uniform =
159 transformation_context.GetFactManager()->GetConstantFromUniformDescriptor(
160 message_.uniform_descriptor());
161 if (!constant_id_associated_with_uniform) {
162 return false;
163 }
164 auto constant_associated_with_uniform =
165 ir_context->get_constant_mgr()->FindDeclaredConstant(
166 constant_id_associated_with_uniform);
167 assert(constant_associated_with_uniform &&
168 "The constant should be present in the module.");
169 if (!constant_associated_with_uniform->AsScalarConstant()) {
170 return false;
171 }
172
173 // The types and values of the scalar value held in the id specified by the id
174 // use descriptor and in the uniform data element specified by the uniform
175 // buffer element descriptor need to match on both type and value.
176 if (!declared_constant->type()->IsSame(
177 constant_associated_with_uniform->type())) {
178 return false;
179 }
180 if (declared_constant->AsScalarConstant()->words() !=
181 constant_associated_with_uniform->AsScalarConstant()->words()) {
182 return false;
183 }
184
185 // The id use descriptor must identify some instruction with respect to the
186 // module.
187 auto instruction_using_constant =
188 FindInstructionContainingUse(message_.id_use_descriptor(), ir_context);
189 if (!instruction_using_constant) {
190 return false;
191 }
192
193 // The use must not be a variable initializer; these are required to be
194 // constants, so it would be illegal to replace one with a uniform access.
195 if (instruction_using_constant->opcode() == spv::Op::OpVariable) {
196 return false;
197 }
198
199 // The module needs to have a uniform pointer type suitable for indexing into
200 // the uniform variable, i.e. matching the type of the constant we wish to
201 // replace with a uniform.
202 opt::analysis::Pointer pointer_to_type_of_constant(
203 declared_constant->type(), spv::StorageClass::Uniform);
204 if (!ir_context->get_type_mgr()->GetId(&pointer_to_type_of_constant)) {
205 return false;
206 }
207
208 // In order to index into the uniform, the module has got to contain the int32
209 // type, plus an OpConstant for each of the indices of interest.
210 opt::analysis::Integer int_type(32, true);
211 if (!ir_context->get_type_mgr()->GetId(&int_type)) {
212 return false;
213 }
214 auto registered_int_type =
215 ir_context->get_type_mgr()->GetRegisteredType(&int_type)->AsInteger();
216 auto int_type_id = ir_context->get_type_mgr()->GetId(&int_type);
217 for (auto index : message_.uniform_descriptor().index()) {
218 opt::analysis::IntConstant int_constant(registered_int_type, {index});
219 if (!ir_context->get_constant_mgr()->FindDeclaredConstant(&int_constant,
220 int_type_id)) {
221 return false;
222 }
223 }
224
225 // Once all checks are completed, we should be able to safely insert
226 // OpAccessChain and OpLoad into the module.
227 assert(GetInsertBeforeInstruction(ir_context) &&
228 "There must exist an instruction that we can use to insert "
229 "OpAccessChain and OpLoad above");
230
231 return true;
232 }
233
Apply(spvtools::opt::IRContext * ir_context,TransformationContext *) const234 void TransformationReplaceConstantWithUniform::Apply(
235 spvtools::opt::IRContext* ir_context,
236 TransformationContext* /*unused*/) const {
237 // Get the instruction that contains the id use we wish to replace.
238 auto* instruction_containing_constant_use =
239 FindInstructionContainingUse(message_.id_use_descriptor(), ir_context);
240 assert(instruction_containing_constant_use &&
241 "Precondition requires that the id use can be found.");
242 assert(instruction_containing_constant_use->GetSingleWordInOperand(
243 message_.id_use_descriptor().in_operand_index()) ==
244 message_.id_use_descriptor().id_of_interest() &&
245 "Does not appear to be a usage of the desired id.");
246
247 // The id of the type for the constant whose use we wish to replace.
248 auto constant_type_id =
249 ir_context->get_def_use_mgr()
250 ->GetDef(message_.id_use_descriptor().id_of_interest())
251 ->type_id();
252
253 // Get an instruction that will be used to insert OpAccessChain and OpLoad.
254 auto* insert_before_inst = GetInsertBeforeInstruction(ir_context);
255 assert(insert_before_inst &&
256 "There must exist an insertion point for OpAccessChain and OpLoad");
257 opt::BasicBlock* enclosing_block =
258 ir_context->get_instr_block(insert_before_inst);
259
260 // Add an access chain instruction to target the uniform element.
261 auto access_chain_instruction =
262 MakeAccessChainInstruction(ir_context, constant_type_id);
263 auto access_chain_instruction_ptr = access_chain_instruction.get();
264 insert_before_inst->InsertBefore(std::move(access_chain_instruction));
265 ir_context->get_def_use_mgr()->AnalyzeInstDefUse(
266 access_chain_instruction_ptr);
267 ir_context->set_instr_block(access_chain_instruction_ptr, enclosing_block);
268
269 // Add a load from this access chain.
270 auto load_instruction = MakeLoadInstruction(ir_context, constant_type_id);
271 auto load_instruction_ptr = load_instruction.get();
272 insert_before_inst->InsertBefore(std::move(load_instruction));
273 ir_context->get_def_use_mgr()->AnalyzeInstDefUse(load_instruction_ptr);
274 ir_context->set_instr_block(load_instruction_ptr, enclosing_block);
275
276 // Adjust the instruction containing the usage of the constant so that this
277 // usage refers instead to the result of the load.
278 instruction_containing_constant_use->SetInOperand(
279 message_.id_use_descriptor().in_operand_index(),
280 {message_.fresh_id_for_load()});
281 ir_context->get_def_use_mgr()->EraseUseRecordsOfOperandIds(
282 instruction_containing_constant_use);
283 ir_context->get_def_use_mgr()->AnalyzeInstUse(
284 instruction_containing_constant_use);
285
286 // Update the module id bound to reflect the new instructions.
287 fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id_for_load());
288 fuzzerutil::UpdateModuleIdBound(ir_context,
289 message_.fresh_id_for_access_chain());
290 }
291
ToMessage() const292 protobufs::Transformation TransformationReplaceConstantWithUniform::ToMessage()
293 const {
294 protobufs::Transformation result;
295 *result.mutable_replace_constant_with_uniform() = message_;
296 return result;
297 }
298
299 std::unordered_set<uint32_t>
GetFreshIds() const300 TransformationReplaceConstantWithUniform::GetFreshIds() const {
301 return {message_.fresh_id_for_access_chain(), message_.fresh_id_for_load()};
302 }
303
304 } // namespace fuzz
305 } // namespace spvtools
306