1 // Copyright (c) 2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/fuzz/transformation_equation_instruction.h"
16
17 #include "source/fuzz/fuzzer_util.h"
18 #include "source/fuzz/instruction_descriptor.h"
19
20 namespace spvtools {
21 namespace fuzz {
22
TransformationEquationInstruction(protobufs::TransformationEquationInstruction message)23 TransformationEquationInstruction::TransformationEquationInstruction(
24 protobufs::TransformationEquationInstruction message)
25 : message_(std::move(message)) {}
26
TransformationEquationInstruction(uint32_t fresh_id,SpvOp opcode,const std::vector<uint32_t> & in_operand_id,const protobufs::InstructionDescriptor & instruction_to_insert_before)27 TransformationEquationInstruction::TransformationEquationInstruction(
28 uint32_t fresh_id, SpvOp opcode, const std::vector<uint32_t>& in_operand_id,
29 const protobufs::InstructionDescriptor& instruction_to_insert_before) {
30 message_.set_fresh_id(fresh_id);
31 message_.set_opcode(opcode);
32 for (auto id : in_operand_id) {
33 message_.add_in_operand_id(id);
34 }
35 *message_.mutable_instruction_to_insert_before() =
36 instruction_to_insert_before;
37 }
38
IsApplicable(opt::IRContext * ir_context,const TransformationContext & transformation_context) const39 bool TransformationEquationInstruction::IsApplicable(
40 opt::IRContext* ir_context,
41 const TransformationContext& transformation_context) const {
42 // The result id must be fresh.
43 if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) {
44 return false;
45 }
46
47 // The instruction to insert before must exist.
48 auto insert_before =
49 FindInstruction(message_.instruction_to_insert_before(), ir_context);
50 if (!insert_before) {
51 return false;
52 }
53 // The input ids must all exist, not be OpUndef, not be irrelevant, and be
54 // available before this instruction.
55 for (auto id : message_.in_operand_id()) {
56 auto inst = ir_context->get_def_use_mgr()->GetDef(id);
57 if (!inst) {
58 return false;
59 }
60 if (inst->opcode() == SpvOpUndef) {
61 return false;
62 }
63 if (transformation_context.GetFactManager()->IdIsIrrelevant(id)) {
64 return false;
65 }
66 if (!fuzzerutil::IdIsAvailableBeforeInstruction(ir_context, insert_before,
67 id)) {
68 return false;
69 }
70 }
71
72 return MaybeGetResultTypeId(ir_context) != 0;
73 }
74
Apply(opt::IRContext * ir_context,TransformationContext * transformation_context) const75 void TransformationEquationInstruction::Apply(
76 opt::IRContext* ir_context,
77 TransformationContext* transformation_context) const {
78 fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id());
79
80 opt::Instruction::OperandList in_operands;
81 std::vector<uint32_t> rhs_id;
82 for (auto id : message_.in_operand_id()) {
83 in_operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
84 rhs_id.push_back(id);
85 }
86
87 auto insert_before =
88 FindInstruction(message_.instruction_to_insert_before(), ir_context);
89 opt::Instruction* new_instruction =
90 insert_before->InsertBefore(MakeUnique<opt::Instruction>(
91 ir_context, static_cast<SpvOp>(message_.opcode()),
92 MaybeGetResultTypeId(ir_context), message_.fresh_id(),
93 std::move(in_operands)));
94
95 ir_context->get_def_use_mgr()->AnalyzeInstDefUse(new_instruction);
96 ir_context->set_instr_block(new_instruction,
97 ir_context->get_instr_block(insert_before));
98
99 // Add an equation fact as long as the result id is not irrelevant (it could
100 // be if we are inserting into a dead block).
101 if (!transformation_context->GetFactManager()->IdIsIrrelevant(
102 message_.fresh_id())) {
103 transformation_context->GetFactManager()->AddFactIdEquation(
104 message_.fresh_id(), static_cast<SpvOp>(message_.opcode()), rhs_id);
105 }
106 }
107
ToMessage() const108 protobufs::Transformation TransformationEquationInstruction::ToMessage() const {
109 protobufs::Transformation result;
110 *result.mutable_equation_instruction() = message_;
111 return result;
112 }
113
MaybeGetResultTypeId(opt::IRContext * ir_context) const114 uint32_t TransformationEquationInstruction::MaybeGetResultTypeId(
115 opt::IRContext* ir_context) const {
116 auto opcode = static_cast<SpvOp>(message_.opcode());
117 switch (opcode) {
118 case SpvOpConvertUToF:
119 case SpvOpConvertSToF: {
120 if (message_.in_operand_id_size() != 1) {
121 return 0;
122 }
123
124 const auto* type = ir_context->get_type_mgr()->GetType(
125 fuzzerutil::GetTypeId(ir_context, message_.in_operand_id(0)));
126 if (!type) {
127 return 0;
128 }
129
130 if (const auto* vector = type->AsVector()) {
131 if (!vector->element_type()->AsInteger()) {
132 return 0;
133 }
134
135 if (auto element_type_id = fuzzerutil::MaybeGetFloatType(
136 ir_context, vector->element_type()->AsInteger()->width())) {
137 return fuzzerutil::MaybeGetVectorType(ir_context, element_type_id,
138 vector->element_count());
139 }
140
141 return 0;
142 } else {
143 if (!type->AsInteger()) {
144 return 0;
145 }
146
147 return fuzzerutil::MaybeGetFloatType(ir_context,
148 type->AsInteger()->width());
149 }
150 }
151 case SpvOpBitcast: {
152 if (message_.in_operand_id_size() != 1) {
153 return 0;
154 }
155
156 const auto* operand_inst =
157 ir_context->get_def_use_mgr()->GetDef(message_.in_operand_id(0));
158 if (!operand_inst) {
159 return 0;
160 }
161
162 const auto* operand_type =
163 ir_context->get_type_mgr()->GetType(operand_inst->type_id());
164 if (!operand_type) {
165 return 0;
166 }
167
168 // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3539):
169 // The only constraint on the types of OpBitcast's parameters is that
170 // they must have the same number of bits. Consider improving the code
171 // below to support this in full.
172 if (const auto* vector = operand_type->AsVector()) {
173 uint32_t component_type_id;
174 if (const auto* int_type = vector->element_type()->AsInteger()) {
175 component_type_id =
176 fuzzerutil::MaybeGetFloatType(ir_context, int_type->width());
177 } else if (const auto* float_type = vector->element_type()->AsFloat()) {
178 component_type_id = fuzzerutil::MaybeGetIntegerType(
179 ir_context, float_type->width(), true);
180 if (component_type_id == 0 ||
181 fuzzerutil::MaybeGetVectorType(ir_context, component_type_id,
182 vector->element_count()) == 0) {
183 component_type_id = fuzzerutil::MaybeGetIntegerType(
184 ir_context, float_type->width(), false);
185 }
186 } else {
187 assert(false && "Only vectors of numerical components are supported");
188 return 0;
189 }
190
191 if (component_type_id == 0) {
192 return 0;
193 }
194
195 return fuzzerutil::MaybeGetVectorType(ir_context, component_type_id,
196 vector->element_count());
197 } else if (const auto* int_type = operand_type->AsInteger()) {
198 return fuzzerutil::MaybeGetFloatType(ir_context, int_type->width());
199 } else if (const auto* float_type = operand_type->AsFloat()) {
200 if (auto existing_id = fuzzerutil::MaybeGetIntegerType(
201 ir_context, float_type->width(), true)) {
202 return existing_id;
203 }
204
205 return fuzzerutil::MaybeGetIntegerType(ir_context, float_type->width(),
206 false);
207 } else {
208 assert(false &&
209 "Operand is not a scalar or a vector of numerical type");
210 return 0;
211 }
212 }
213 case SpvOpIAdd:
214 case SpvOpISub: {
215 if (message_.in_operand_id_size() != 2) {
216 return 0;
217 }
218 uint32_t first_operand_width = 0;
219 uint32_t first_operand_type_id = 0;
220 for (uint32_t index = 0; index < 2; index++) {
221 auto operand_inst = ir_context->get_def_use_mgr()->GetDef(
222 message_.in_operand_id(index));
223 if (!operand_inst || !operand_inst->type_id()) {
224 return 0;
225 }
226 auto operand_type =
227 ir_context->get_type_mgr()->GetType(operand_inst->type_id());
228 if (!(operand_type->AsInteger() ||
229 (operand_type->AsVector() &&
230 operand_type->AsVector()->element_type()->AsInteger()))) {
231 return 0;
232 }
233 uint32_t operand_width =
234 operand_type->AsInteger()
235 ? 1
236 : operand_type->AsVector()->element_count();
237 if (index == 0) {
238 first_operand_width = operand_width;
239 first_operand_type_id = operand_inst->type_id();
240 } else {
241 assert(first_operand_width != 0 &&
242 "The first operand should have been processed.");
243 if (operand_width != first_operand_width) {
244 return 0;
245 }
246 }
247 }
248 assert(first_operand_type_id != 0 &&
249 "A type must have been found for the first operand.");
250 return first_operand_type_id;
251 }
252 case SpvOpLogicalNot: {
253 if (message_.in_operand_id().size() != 1) {
254 return 0;
255 }
256 auto operand_inst =
257 ir_context->get_def_use_mgr()->GetDef(message_.in_operand_id(0));
258 if (!operand_inst || !operand_inst->type_id()) {
259 return 0;
260 }
261 auto operand_type =
262 ir_context->get_type_mgr()->GetType(operand_inst->type_id());
263 if (!(operand_type->AsBool() ||
264 (operand_type->AsVector() &&
265 operand_type->AsVector()->element_type()->AsBool()))) {
266 return 0;
267 }
268 return operand_inst->type_id();
269 }
270 case SpvOpSNegate: {
271 if (message_.in_operand_id().size() != 1) {
272 return 0;
273 }
274 auto operand_inst =
275 ir_context->get_def_use_mgr()->GetDef(message_.in_operand_id(0));
276 if (!operand_inst || !operand_inst->type_id()) {
277 return 0;
278 }
279 auto operand_type =
280 ir_context->get_type_mgr()->GetType(operand_inst->type_id());
281 if (!(operand_type->AsInteger() ||
282 (operand_type->AsVector() &&
283 operand_type->AsVector()->element_type()->AsInteger()))) {
284 return 0;
285 }
286 return operand_inst->type_id();
287 }
288 default:
289 assert(false && "Inappropriate opcode for equation instruction.");
290 return 0;
291 }
292 }
293
GetFreshIds() const294 std::unordered_set<uint32_t> TransformationEquationInstruction::GetFreshIds()
295 const {
296 return {message_.fresh_id()};
297 }
298
299 } // namespace fuzz
300 } // namespace spvtools
301