1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/fuzz/transformation_replace_constant_with_uniform.h"
16
17 #include "source/fuzz/fuzzer_util.h"
18 #include "source/fuzz/uniform_buffer_element_descriptor.h"
19
20 namespace spvtools {
21 namespace fuzz {
22
23 TransformationReplaceConstantWithUniform::
TransformationReplaceConstantWithUniform(protobufs::TransformationReplaceConstantWithUniform message)24 TransformationReplaceConstantWithUniform(
25 protobufs::TransformationReplaceConstantWithUniform message)
26 : message_(std::move(message)) {}
27
28 TransformationReplaceConstantWithUniform::
TransformationReplaceConstantWithUniform(protobufs::IdUseDescriptor id_use,protobufs::UniformBufferElementDescriptor uniform_descriptor,uint32_t fresh_id_for_access_chain,uint32_t fresh_id_for_load)29 TransformationReplaceConstantWithUniform(
30 protobufs::IdUseDescriptor id_use,
31 protobufs::UniformBufferElementDescriptor uniform_descriptor,
32 uint32_t fresh_id_for_access_chain, uint32_t fresh_id_for_load) {
33 *message_.mutable_id_use_descriptor() = std::move(id_use);
34 *message_.mutable_uniform_descriptor() = std::move(uniform_descriptor);
35 message_.set_fresh_id_for_access_chain(fresh_id_for_access_chain);
36 message_.set_fresh_id_for_load(fresh_id_for_load);
37 }
38
39 std::unique_ptr<opt::Instruction>
MakeAccessChainInstruction(spvtools::opt::IRContext * ir_context,uint32_t constant_type_id) const40 TransformationReplaceConstantWithUniform::MakeAccessChainInstruction(
41 spvtools::opt::IRContext* ir_context, uint32_t constant_type_id) const {
42 // The input operands for the access chain.
43 opt::Instruction::OperandList operands_for_access_chain;
44
45 opt::Instruction* uniform_variable =
46 FindUniformVariable(message_.uniform_descriptor(), ir_context, false);
47
48 // The first input operand is the id of the uniform variable.
49 operands_for_access_chain.push_back(
50 {SPV_OPERAND_TYPE_ID, {uniform_variable->result_id()}});
51
52 // The other input operands are the ids of the constants used to index into
53 // the uniform. The uniform buffer descriptor specifies a series of literals;
54 // for each we find the id of the instruction that defines it, and add these
55 // instruction ids as operands.
56 opt::analysis::Integer int_type(32, true);
57 auto registered_int_type =
58 ir_context->get_type_mgr()->GetRegisteredType(&int_type)->AsInteger();
59 auto int_type_id = ir_context->get_type_mgr()->GetId(&int_type);
60 for (auto index : message_.uniform_descriptor().index()) {
61 opt::analysis::IntConstant int_constant(registered_int_type, {index});
62 auto constant_id = ir_context->get_constant_mgr()->FindDeclaredConstant(
63 &int_constant, int_type_id);
64 operands_for_access_chain.push_back({SPV_OPERAND_TYPE_ID, {constant_id}});
65 }
66
67 // The type id for the access chain is a uniform pointer with base type
68 // matching the given constant id type.
69 auto type_and_pointer_type =
70 ir_context->get_type_mgr()->GetTypeAndPointerType(constant_type_id,
71 SpvStorageClassUniform);
72 assert(type_and_pointer_type.first != nullptr);
73 assert(type_and_pointer_type.second != nullptr);
74 auto pointer_to_uniform_constant_type_id =
75 ir_context->get_type_mgr()->GetId(type_and_pointer_type.second.get());
76
77 return MakeUnique<opt::Instruction>(
78 ir_context, SpvOpAccessChain, pointer_to_uniform_constant_type_id,
79 message_.fresh_id_for_access_chain(), operands_for_access_chain);
80 }
81
82 std::unique_ptr<opt::Instruction>
MakeLoadInstruction(spvtools::opt::IRContext * ir_context,uint32_t constant_type_id) const83 TransformationReplaceConstantWithUniform::MakeLoadInstruction(
84 spvtools::opt::IRContext* ir_context, uint32_t constant_type_id) const {
85 opt::Instruction::OperandList operands_for_load = {
86 {SPV_OPERAND_TYPE_ID, {message_.fresh_id_for_access_chain()}}};
87 return MakeUnique<opt::Instruction>(ir_context, SpvOpLoad, constant_type_id,
88 message_.fresh_id_for_load(),
89 operands_for_load);
90 }
91
92 opt::Instruction*
GetInsertBeforeInstruction(opt::IRContext * ir_context) const93 TransformationReplaceConstantWithUniform::GetInsertBeforeInstruction(
94 opt::IRContext* ir_context) const {
95 auto* result =
96 FindInstructionContainingUse(message_.id_use_descriptor(), ir_context);
97 if (!result) {
98 return nullptr;
99 }
100
101 // The use might be in an OpPhi instruction.
102 if (result->opcode() == SpvOpPhi) {
103 // OpPhi instructions must be the first instructions in a block. Thus, we
104 // can't insert above the OpPhi instruction. Given the predecessor block
105 // that corresponds to the id use, get the last instruction in that block
106 // above which we can insert OpAccessChain and OpLoad.
107 return fuzzerutil::GetLastInsertBeforeInstruction(
108 ir_context,
109 result->GetSingleWordInOperand(
110 message_.id_use_descriptor().in_operand_index() + 1),
111 SpvOpLoad);
112 }
113
114 // The only operand that we could've replaced in the OpBranchConditional is
115 // the condition id. But that operand has a boolean type and uniform variables
116 // can't store booleans (see the spec on OpTypeBool). Thus, |result| can't be
117 // an OpBranchConditional.
118 assert(result->opcode() != SpvOpBranchConditional &&
119 "OpBranchConditional has no operands to replace");
120
121 assert(fuzzerutil::CanInsertOpcodeBeforeInstruction(SpvOpLoad, result) &&
122 "We should be able to insert OpLoad and OpAccessChain at this point");
123 return result;
124 }
125
IsApplicable(opt::IRContext * ir_context,const TransformationContext & transformation_context) const126 bool TransformationReplaceConstantWithUniform::IsApplicable(
127 opt::IRContext* ir_context,
128 const TransformationContext& transformation_context) const {
129 // The following is really an invariant of the transformation rather than
130 // merely a requirement of the precondition. We check it here since we cannot
131 // check it in the message_ constructor.
132 assert(message_.fresh_id_for_access_chain() != message_.fresh_id_for_load() &&
133 "Fresh ids for access chain and load result cannot be the same.");
134
135 // The ids for the access chain and load instructions must both be fresh.
136 if (!fuzzerutil::IsFreshId(ir_context,
137 message_.fresh_id_for_access_chain())) {
138 return false;
139 }
140 if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id_for_load())) {
141 return false;
142 }
143
144 // The id specified in the id use descriptor must be that of a declared scalar
145 // constant.
146 auto declared_constant = ir_context->get_constant_mgr()->FindDeclaredConstant(
147 message_.id_use_descriptor().id_of_interest());
148 if (!declared_constant) {
149 return false;
150 }
151 if (!declared_constant->AsScalarConstant()) {
152 return false;
153 }
154
155 // The fact manager needs to believe that the uniform data element described
156 // by the uniform buffer element descriptor will hold a scalar value.
157 auto constant_id_associated_with_uniform =
158 transformation_context.GetFactManager()->GetConstantFromUniformDescriptor(
159 message_.uniform_descriptor());
160 if (!constant_id_associated_with_uniform) {
161 return false;
162 }
163 auto constant_associated_with_uniform =
164 ir_context->get_constant_mgr()->FindDeclaredConstant(
165 constant_id_associated_with_uniform);
166 assert(constant_associated_with_uniform &&
167 "The constant should be present in the module.");
168 if (!constant_associated_with_uniform->AsScalarConstant()) {
169 return false;
170 }
171
172 // The types and values of the scalar value held in the id specified by the id
173 // use descriptor and in the uniform data element specified by the uniform
174 // buffer element descriptor need to match on both type and value.
175 if (!declared_constant->type()->IsSame(
176 constant_associated_with_uniform->type())) {
177 return false;
178 }
179 if (declared_constant->AsScalarConstant()->words() !=
180 constant_associated_with_uniform->AsScalarConstant()->words()) {
181 return false;
182 }
183
184 // The id use descriptor must identify some instruction with respect to the
185 // module.
186 auto instruction_using_constant =
187 FindInstructionContainingUse(message_.id_use_descriptor(), ir_context);
188 if (!instruction_using_constant) {
189 return false;
190 }
191
192 // The use must not be a variable initializer; these are required to be
193 // constants, so it would be illegal to replace one with a uniform access.
194 if (instruction_using_constant->opcode() == SpvOpVariable) {
195 return false;
196 }
197
198 // The module needs to have a uniform pointer type suitable for indexing into
199 // the uniform variable, i.e. matching the type of the constant we wish to
200 // replace with a uniform.
201 opt::analysis::Pointer pointer_to_type_of_constant(declared_constant->type(),
202 SpvStorageClassUniform);
203 if (!ir_context->get_type_mgr()->GetId(&pointer_to_type_of_constant)) {
204 return false;
205 }
206
207 // In order to index into the uniform, the module has got to contain the int32
208 // type, plus an OpConstant for each of the indices of interest.
209 opt::analysis::Integer int_type(32, true);
210 if (!ir_context->get_type_mgr()->GetId(&int_type)) {
211 return false;
212 }
213 auto registered_int_type =
214 ir_context->get_type_mgr()->GetRegisteredType(&int_type)->AsInteger();
215 auto int_type_id = ir_context->get_type_mgr()->GetId(&int_type);
216 for (auto index : message_.uniform_descriptor().index()) {
217 opt::analysis::IntConstant int_constant(registered_int_type, {index});
218 if (!ir_context->get_constant_mgr()->FindDeclaredConstant(&int_constant,
219 int_type_id)) {
220 return false;
221 }
222 }
223
224 // Once all checks are completed, we should be able to safely insert
225 // OpAccessChain and OpLoad into the module.
226 assert(GetInsertBeforeInstruction(ir_context) &&
227 "There must exist an instruction that we can use to insert "
228 "OpAccessChain and OpLoad above");
229
230 return true;
231 }
232
Apply(spvtools::opt::IRContext * ir_context,TransformationContext *) const233 void TransformationReplaceConstantWithUniform::Apply(
234 spvtools::opt::IRContext* ir_context,
235 TransformationContext* /*unused*/) const {
236 // Get the instruction that contains the id use we wish to replace.
237 auto* instruction_containing_constant_use =
238 FindInstructionContainingUse(message_.id_use_descriptor(), ir_context);
239 assert(instruction_containing_constant_use &&
240 "Precondition requires that the id use can be found.");
241 assert(instruction_containing_constant_use->GetSingleWordInOperand(
242 message_.id_use_descriptor().in_operand_index()) ==
243 message_.id_use_descriptor().id_of_interest() &&
244 "Does not appear to be a usage of the desired id.");
245
246 // The id of the type for the constant whose use we wish to replace.
247 auto constant_type_id =
248 ir_context->get_def_use_mgr()
249 ->GetDef(message_.id_use_descriptor().id_of_interest())
250 ->type_id();
251
252 // Get an instruction that will be used to insert OpAccessChain and OpLoad.
253 auto* insert_before_inst = GetInsertBeforeInstruction(ir_context);
254 assert(insert_before_inst &&
255 "There must exist an insertion point for OpAccessChain and OpLoad");
256 opt::BasicBlock* enclosing_block =
257 ir_context->get_instr_block(insert_before_inst);
258
259 // Add an access chain instruction to target the uniform element.
260 auto access_chain_instruction =
261 MakeAccessChainInstruction(ir_context, constant_type_id);
262 auto access_chain_instruction_ptr = access_chain_instruction.get();
263 insert_before_inst->InsertBefore(std::move(access_chain_instruction));
264 ir_context->get_def_use_mgr()->AnalyzeInstDefUse(
265 access_chain_instruction_ptr);
266 ir_context->set_instr_block(access_chain_instruction_ptr, enclosing_block);
267
268 // Add a load from this access chain.
269 auto load_instruction = MakeLoadInstruction(ir_context, constant_type_id);
270 auto load_instruction_ptr = load_instruction.get();
271 insert_before_inst->InsertBefore(std::move(load_instruction));
272 ir_context->get_def_use_mgr()->AnalyzeInstDefUse(load_instruction_ptr);
273 ir_context->set_instr_block(load_instruction_ptr, enclosing_block);
274
275 // Adjust the instruction containing the usage of the constant so that this
276 // usage refers instead to the result of the load.
277 instruction_containing_constant_use->SetInOperand(
278 message_.id_use_descriptor().in_operand_index(),
279 {message_.fresh_id_for_load()});
280 ir_context->get_def_use_mgr()->EraseUseRecordsOfOperandIds(
281 instruction_containing_constant_use);
282 ir_context->get_def_use_mgr()->AnalyzeInstUse(
283 instruction_containing_constant_use);
284
285 // Update the module id bound to reflect the new instructions.
286 fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id_for_load());
287 fuzzerutil::UpdateModuleIdBound(ir_context,
288 message_.fresh_id_for_access_chain());
289 }
290
ToMessage() const291 protobufs::Transformation TransformationReplaceConstantWithUniform::ToMessage()
292 const {
293 protobufs::Transformation result;
294 *result.mutable_replace_constant_with_uniform() = message_;
295 return result;
296 }
297
298 std::unordered_set<uint32_t>
GetFreshIds() const299 TransformationReplaceConstantWithUniform::GetFreshIds() const {
300 return {message_.fresh_id_for_access_chain(), message_.fresh_id_for_load()};
301 }
302
303 } // namespace fuzz
304 } // namespace spvtools
305