• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/fuzz/transformation_equation_instruction.h"
16 
17 #include "source/fuzz/fuzzer_util.h"
18 #include "source/fuzz/instruction_descriptor.h"
19 
20 namespace spvtools {
21 namespace fuzz {
22 
TransformationEquationInstruction(protobufs::TransformationEquationInstruction message)23 TransformationEquationInstruction::TransformationEquationInstruction(
24     protobufs::TransformationEquationInstruction message)
25     : message_(std::move(message)) {}
26 
TransformationEquationInstruction(uint32_t fresh_id,spv::Op opcode,const std::vector<uint32_t> & in_operand_id,const protobufs::InstructionDescriptor & instruction_to_insert_before)27 TransformationEquationInstruction::TransformationEquationInstruction(
28     uint32_t fresh_id, spv::Op opcode,
29     const std::vector<uint32_t>& in_operand_id,
30     const protobufs::InstructionDescriptor& instruction_to_insert_before) {
31   message_.set_fresh_id(fresh_id);
32   message_.set_opcode(uint32_t(opcode));
33   for (auto id : in_operand_id) {
34     message_.add_in_operand_id(id);
35   }
36   *message_.mutable_instruction_to_insert_before() =
37       instruction_to_insert_before;
38 }
39 
IsApplicable(opt::IRContext * ir_context,const TransformationContext & transformation_context) const40 bool TransformationEquationInstruction::IsApplicable(
41     opt::IRContext* ir_context,
42     const TransformationContext& transformation_context) const {
43   // The result id must be fresh.
44   if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) {
45     return false;
46   }
47 
48   // The instruction to insert before must exist.
49   auto insert_before =
50       FindInstruction(message_.instruction_to_insert_before(), ir_context);
51   if (!insert_before) {
52     return false;
53   }
54   // The input ids must all exist, not be OpUndef, not be irrelevant, and be
55   // available before this instruction.
56   for (auto id : message_.in_operand_id()) {
57     auto inst = ir_context->get_def_use_mgr()->GetDef(id);
58     if (!inst) {
59       return false;
60     }
61     if (inst->opcode() == spv::Op::OpUndef) {
62       return false;
63     }
64     if (transformation_context.GetFactManager()->IdIsIrrelevant(id)) {
65       return false;
66     }
67     if (!fuzzerutil::IdIsAvailableBeforeInstruction(ir_context, insert_before,
68                                                     id)) {
69       return false;
70     }
71   }
72 
73   return MaybeGetResultTypeId(ir_context) != 0;
74 }
75 
Apply(opt::IRContext * ir_context,TransformationContext * transformation_context) const76 void TransformationEquationInstruction::Apply(
77     opt::IRContext* ir_context,
78     TransformationContext* transformation_context) const {
79   fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id());
80 
81   opt::Instruction::OperandList in_operands;
82   std::vector<uint32_t> rhs_id;
83   for (auto id : message_.in_operand_id()) {
84     in_operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
85     rhs_id.push_back(id);
86   }
87 
88   auto insert_before =
89       FindInstruction(message_.instruction_to_insert_before(), ir_context);
90   opt::Instruction* new_instruction =
91       insert_before->InsertBefore(MakeUnique<opt::Instruction>(
92           ir_context, static_cast<spv::Op>(message_.opcode()),
93           MaybeGetResultTypeId(ir_context), message_.fresh_id(),
94           std::move(in_operands)));
95 
96   ir_context->get_def_use_mgr()->AnalyzeInstDefUse(new_instruction);
97   ir_context->set_instr_block(new_instruction,
98                               ir_context->get_instr_block(insert_before));
99 
100   // Add an equation fact as long as the result id is not irrelevant (it could
101   // be if we are inserting into a dead block).
102   if (!transformation_context->GetFactManager()->IdIsIrrelevant(
103           message_.fresh_id())) {
104     transformation_context->GetFactManager()->AddFactIdEquation(
105         message_.fresh_id(), static_cast<spv::Op>(message_.opcode()), rhs_id);
106   }
107 }
108 
ToMessage() const109 protobufs::Transformation TransformationEquationInstruction::ToMessage() const {
110   protobufs::Transformation result;
111   *result.mutable_equation_instruction() = message_;
112   return result;
113 }
114 
MaybeGetResultTypeId(opt::IRContext * ir_context) const115 uint32_t TransformationEquationInstruction::MaybeGetResultTypeId(
116     opt::IRContext* ir_context) const {
117   auto opcode = static_cast<spv::Op>(message_.opcode());
118   switch (opcode) {
119     case spv::Op::OpConvertUToF:
120     case spv::Op::OpConvertSToF: {
121       if (message_.in_operand_id_size() != 1) {
122         return 0;
123       }
124 
125       const auto* type = ir_context->get_type_mgr()->GetType(
126           fuzzerutil::GetTypeId(ir_context, message_.in_operand_id(0)));
127       if (!type) {
128         return 0;
129       }
130 
131       if (const auto* vector = type->AsVector()) {
132         if (!vector->element_type()->AsInteger()) {
133           return 0;
134         }
135 
136         if (auto element_type_id = fuzzerutil::MaybeGetFloatType(
137                 ir_context, vector->element_type()->AsInteger()->width())) {
138           return fuzzerutil::MaybeGetVectorType(ir_context, element_type_id,
139                                                 vector->element_count());
140         }
141 
142         return 0;
143       } else {
144         if (!type->AsInteger()) {
145           return 0;
146         }
147 
148         return fuzzerutil::MaybeGetFloatType(ir_context,
149                                              type->AsInteger()->width());
150       }
151     }
152     case spv::Op::OpBitcast: {
153       if (message_.in_operand_id_size() != 1) {
154         return 0;
155       }
156 
157       const auto* operand_inst =
158           ir_context->get_def_use_mgr()->GetDef(message_.in_operand_id(0));
159       if (!operand_inst) {
160         return 0;
161       }
162 
163       const auto* operand_type =
164           ir_context->get_type_mgr()->GetType(operand_inst->type_id());
165       if (!operand_type) {
166         return 0;
167       }
168 
169       // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3539):
170       //  The only constraint on the types of OpBitcast's parameters is that
171       //  they must have the same number of bits. Consider improving the code
172       //  below to support this in full.
173       if (const auto* vector = operand_type->AsVector()) {
174         uint32_t component_type_id;
175         if (const auto* int_type = vector->element_type()->AsInteger()) {
176           component_type_id =
177               fuzzerutil::MaybeGetFloatType(ir_context, int_type->width());
178         } else if (const auto* float_type = vector->element_type()->AsFloat()) {
179           component_type_id = fuzzerutil::MaybeGetIntegerType(
180               ir_context, float_type->width(), true);
181           if (component_type_id == 0 ||
182               fuzzerutil::MaybeGetVectorType(ir_context, component_type_id,
183                                              vector->element_count()) == 0) {
184             component_type_id = fuzzerutil::MaybeGetIntegerType(
185                 ir_context, float_type->width(), false);
186           }
187         } else {
188           assert(false && "Only vectors of numerical components are supported");
189           return 0;
190         }
191 
192         if (component_type_id == 0) {
193           return 0;
194         }
195 
196         return fuzzerutil::MaybeGetVectorType(ir_context, component_type_id,
197                                               vector->element_count());
198       } else if (const auto* int_type = operand_type->AsInteger()) {
199         return fuzzerutil::MaybeGetFloatType(ir_context, int_type->width());
200       } else if (const auto* float_type = operand_type->AsFloat()) {
201         if (auto existing_id = fuzzerutil::MaybeGetIntegerType(
202                 ir_context, float_type->width(), true)) {
203           return existing_id;
204         }
205 
206         return fuzzerutil::MaybeGetIntegerType(ir_context, float_type->width(),
207                                                false);
208       } else {
209         assert(false &&
210                "Operand is not a scalar or a vector of numerical type");
211         return 0;
212       }
213     }
214     case spv::Op::OpIAdd:
215     case spv::Op::OpISub: {
216       if (message_.in_operand_id_size() != 2) {
217         return 0;
218       }
219       uint32_t first_operand_width = 0;
220       uint32_t first_operand_type_id = 0;
221       for (uint32_t index = 0; index < 2; index++) {
222         auto operand_inst = ir_context->get_def_use_mgr()->GetDef(
223             message_.in_operand_id(index));
224         if (!operand_inst || !operand_inst->type_id()) {
225           return 0;
226         }
227         auto operand_type =
228             ir_context->get_type_mgr()->GetType(operand_inst->type_id());
229         if (!(operand_type->AsInteger() ||
230               (operand_type->AsVector() &&
231                operand_type->AsVector()->element_type()->AsInteger()))) {
232           return 0;
233         }
234         uint32_t operand_width =
235             operand_type->AsInteger()
236                 ? 1
237                 : operand_type->AsVector()->element_count();
238         if (index == 0) {
239           first_operand_width = operand_width;
240           first_operand_type_id = operand_inst->type_id();
241         } else {
242           assert(first_operand_width != 0 &&
243                  "The first operand should have been processed.");
244           if (operand_width != first_operand_width) {
245             return 0;
246           }
247         }
248       }
249       assert(first_operand_type_id != 0 &&
250              "A type must have been found for the first operand.");
251       return first_operand_type_id;
252     }
253     case spv::Op::OpLogicalNot: {
254       if (message_.in_operand_id().size() != 1) {
255         return 0;
256       }
257       auto operand_inst =
258           ir_context->get_def_use_mgr()->GetDef(message_.in_operand_id(0));
259       if (!operand_inst || !operand_inst->type_id()) {
260         return 0;
261       }
262       auto operand_type =
263           ir_context->get_type_mgr()->GetType(operand_inst->type_id());
264       if (!(operand_type->AsBool() ||
265             (operand_type->AsVector() &&
266              operand_type->AsVector()->element_type()->AsBool()))) {
267         return 0;
268       }
269       return operand_inst->type_id();
270     }
271     case spv::Op::OpSNegate: {
272       if (message_.in_operand_id().size() != 1) {
273         return 0;
274       }
275       auto operand_inst =
276           ir_context->get_def_use_mgr()->GetDef(message_.in_operand_id(0));
277       if (!operand_inst || !operand_inst->type_id()) {
278         return 0;
279       }
280       auto operand_type =
281           ir_context->get_type_mgr()->GetType(operand_inst->type_id());
282       if (!(operand_type->AsInteger() ||
283             (operand_type->AsVector() &&
284              operand_type->AsVector()->element_type()->AsInteger()))) {
285         return 0;
286       }
287       return operand_inst->type_id();
288     }
289     default:
290       assert(false && "Inappropriate opcode for equation instruction.");
291       return 0;
292   }
293 }
294 
GetFreshIds() const295 std::unordered_set<uint32_t> TransformationEquationInstruction::GetFreshIds()
296     const {
297   return {message_.fresh_id()};
298 }
299 
300 }  // namespace fuzz
301 }  // namespace spvtools
302