• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/fuzz/transformation_equation_instruction.h"
16 
17 #include "source/fuzz/fuzzer_util.h"
18 #include "source/fuzz/instruction_descriptor.h"
19 
20 namespace spvtools {
21 namespace fuzz {
22 
TransformationEquationInstruction(protobufs::TransformationEquationInstruction message)23 TransformationEquationInstruction::TransformationEquationInstruction(
24     protobufs::TransformationEquationInstruction message)
25     : message_(std::move(message)) {}
26 
TransformationEquationInstruction(uint32_t fresh_id,SpvOp opcode,const std::vector<uint32_t> & in_operand_id,const protobufs::InstructionDescriptor & instruction_to_insert_before)27 TransformationEquationInstruction::TransformationEquationInstruction(
28     uint32_t fresh_id, SpvOp opcode, const std::vector<uint32_t>& in_operand_id,
29     const protobufs::InstructionDescriptor& instruction_to_insert_before) {
30   message_.set_fresh_id(fresh_id);
31   message_.set_opcode(opcode);
32   for (auto id : in_operand_id) {
33     message_.add_in_operand_id(id);
34   }
35   *message_.mutable_instruction_to_insert_before() =
36       instruction_to_insert_before;
37 }
38 
IsApplicable(opt::IRContext * ir_context,const TransformationContext & transformation_context) const39 bool TransformationEquationInstruction::IsApplicable(
40     opt::IRContext* ir_context,
41     const TransformationContext& transformation_context) const {
42   // The result id must be fresh.
43   if (!fuzzerutil::IsFreshId(ir_context, message_.fresh_id())) {
44     return false;
45   }
46 
47   // The instruction to insert before must exist.
48   auto insert_before =
49       FindInstruction(message_.instruction_to_insert_before(), ir_context);
50   if (!insert_before) {
51     return false;
52   }
53   // The input ids must all exist, not be OpUndef, not be irrelevant, and be
54   // available before this instruction.
55   for (auto id : message_.in_operand_id()) {
56     auto inst = ir_context->get_def_use_mgr()->GetDef(id);
57     if (!inst) {
58       return false;
59     }
60     if (inst->opcode() == SpvOpUndef) {
61       return false;
62     }
63     if (transformation_context.GetFactManager()->IdIsIrrelevant(id)) {
64       return false;
65     }
66     if (!fuzzerutil::IdIsAvailableBeforeInstruction(ir_context, insert_before,
67                                                     id)) {
68       return false;
69     }
70   }
71 
72   return MaybeGetResultTypeId(ir_context) != 0;
73 }
74 
Apply(opt::IRContext * ir_context,TransformationContext * transformation_context) const75 void TransformationEquationInstruction::Apply(
76     opt::IRContext* ir_context,
77     TransformationContext* transformation_context) const {
78   fuzzerutil::UpdateModuleIdBound(ir_context, message_.fresh_id());
79 
80   opt::Instruction::OperandList in_operands;
81   std::vector<uint32_t> rhs_id;
82   for (auto id : message_.in_operand_id()) {
83     in_operands.push_back({SPV_OPERAND_TYPE_ID, {id}});
84     rhs_id.push_back(id);
85   }
86 
87   auto insert_before =
88       FindInstruction(message_.instruction_to_insert_before(), ir_context);
89   opt::Instruction* new_instruction =
90       insert_before->InsertBefore(MakeUnique<opt::Instruction>(
91           ir_context, static_cast<SpvOp>(message_.opcode()),
92           MaybeGetResultTypeId(ir_context), message_.fresh_id(),
93           std::move(in_operands)));
94 
95   ir_context->get_def_use_mgr()->AnalyzeInstDefUse(new_instruction);
96   ir_context->set_instr_block(new_instruction,
97                               ir_context->get_instr_block(insert_before));
98 
99   // Add an equation fact as long as the result id is not irrelevant (it could
100   // be if we are inserting into a dead block).
101   if (!transformation_context->GetFactManager()->IdIsIrrelevant(
102           message_.fresh_id())) {
103     transformation_context->GetFactManager()->AddFactIdEquation(
104         message_.fresh_id(), static_cast<SpvOp>(message_.opcode()), rhs_id);
105   }
106 }
107 
ToMessage() const108 protobufs::Transformation TransformationEquationInstruction::ToMessage() const {
109   protobufs::Transformation result;
110   *result.mutable_equation_instruction() = message_;
111   return result;
112 }
113 
MaybeGetResultTypeId(opt::IRContext * ir_context) const114 uint32_t TransformationEquationInstruction::MaybeGetResultTypeId(
115     opt::IRContext* ir_context) const {
116   auto opcode = static_cast<SpvOp>(message_.opcode());
117   switch (opcode) {
118     case SpvOpConvertUToF:
119     case SpvOpConvertSToF: {
120       if (message_.in_operand_id_size() != 1) {
121         return 0;
122       }
123 
124       const auto* type = ir_context->get_type_mgr()->GetType(
125           fuzzerutil::GetTypeId(ir_context, message_.in_operand_id(0)));
126       if (!type) {
127         return 0;
128       }
129 
130       if (const auto* vector = type->AsVector()) {
131         if (!vector->element_type()->AsInteger()) {
132           return 0;
133         }
134 
135         if (auto element_type_id = fuzzerutil::MaybeGetFloatType(
136                 ir_context, vector->element_type()->AsInteger()->width())) {
137           return fuzzerutil::MaybeGetVectorType(ir_context, element_type_id,
138                                                 vector->element_count());
139         }
140 
141         return 0;
142       } else {
143         if (!type->AsInteger()) {
144           return 0;
145         }
146 
147         return fuzzerutil::MaybeGetFloatType(ir_context,
148                                              type->AsInteger()->width());
149       }
150     }
151     case SpvOpBitcast: {
152       if (message_.in_operand_id_size() != 1) {
153         return 0;
154       }
155 
156       const auto* operand_inst =
157           ir_context->get_def_use_mgr()->GetDef(message_.in_operand_id(0));
158       if (!operand_inst) {
159         return 0;
160       }
161 
162       const auto* operand_type =
163           ir_context->get_type_mgr()->GetType(operand_inst->type_id());
164       if (!operand_type) {
165         return 0;
166       }
167 
168       // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3539):
169       //  The only constraint on the types of OpBitcast's parameters is that
170       //  they must have the same number of bits. Consider improving the code
171       //  below to support this in full.
172       if (const auto* vector = operand_type->AsVector()) {
173         uint32_t component_type_id;
174         if (const auto* int_type = vector->element_type()->AsInteger()) {
175           component_type_id =
176               fuzzerutil::MaybeGetFloatType(ir_context, int_type->width());
177         } else if (const auto* float_type = vector->element_type()->AsFloat()) {
178           component_type_id = fuzzerutil::MaybeGetIntegerType(
179               ir_context, float_type->width(), true);
180           if (component_type_id == 0 ||
181               fuzzerutil::MaybeGetVectorType(ir_context, component_type_id,
182                                              vector->element_count()) == 0) {
183             component_type_id = fuzzerutil::MaybeGetIntegerType(
184                 ir_context, float_type->width(), false);
185           }
186         } else {
187           assert(false && "Only vectors of numerical components are supported");
188           return 0;
189         }
190 
191         if (component_type_id == 0) {
192           return 0;
193         }
194 
195         return fuzzerutil::MaybeGetVectorType(ir_context, component_type_id,
196                                               vector->element_count());
197       } else if (const auto* int_type = operand_type->AsInteger()) {
198         return fuzzerutil::MaybeGetFloatType(ir_context, int_type->width());
199       } else if (const auto* float_type = operand_type->AsFloat()) {
200         if (auto existing_id = fuzzerutil::MaybeGetIntegerType(
201                 ir_context, float_type->width(), true)) {
202           return existing_id;
203         }
204 
205         return fuzzerutil::MaybeGetIntegerType(ir_context, float_type->width(),
206                                                false);
207       } else {
208         assert(false &&
209                "Operand is not a scalar or a vector of numerical type");
210         return 0;
211       }
212     }
213     case SpvOpIAdd:
214     case SpvOpISub: {
215       if (message_.in_operand_id_size() != 2) {
216         return 0;
217       }
218       uint32_t first_operand_width = 0;
219       uint32_t first_operand_type_id = 0;
220       for (uint32_t index = 0; index < 2; index++) {
221         auto operand_inst = ir_context->get_def_use_mgr()->GetDef(
222             message_.in_operand_id(index));
223         if (!operand_inst || !operand_inst->type_id()) {
224           return 0;
225         }
226         auto operand_type =
227             ir_context->get_type_mgr()->GetType(operand_inst->type_id());
228         if (!(operand_type->AsInteger() ||
229               (operand_type->AsVector() &&
230                operand_type->AsVector()->element_type()->AsInteger()))) {
231           return 0;
232         }
233         uint32_t operand_width =
234             operand_type->AsInteger()
235                 ? 1
236                 : operand_type->AsVector()->element_count();
237         if (index == 0) {
238           first_operand_width = operand_width;
239           first_operand_type_id = operand_inst->type_id();
240         } else {
241           assert(first_operand_width != 0 &&
242                  "The first operand should have been processed.");
243           if (operand_width != first_operand_width) {
244             return 0;
245           }
246         }
247       }
248       assert(first_operand_type_id != 0 &&
249              "A type must have been found for the first operand.");
250       return first_operand_type_id;
251     }
252     case SpvOpLogicalNot: {
253       if (message_.in_operand_id().size() != 1) {
254         return 0;
255       }
256       auto operand_inst =
257           ir_context->get_def_use_mgr()->GetDef(message_.in_operand_id(0));
258       if (!operand_inst || !operand_inst->type_id()) {
259         return 0;
260       }
261       auto operand_type =
262           ir_context->get_type_mgr()->GetType(operand_inst->type_id());
263       if (!(operand_type->AsBool() ||
264             (operand_type->AsVector() &&
265              operand_type->AsVector()->element_type()->AsBool()))) {
266         return 0;
267       }
268       return operand_inst->type_id();
269     }
270     case SpvOpSNegate: {
271       if (message_.in_operand_id().size() != 1) {
272         return 0;
273       }
274       auto operand_inst =
275           ir_context->get_def_use_mgr()->GetDef(message_.in_operand_id(0));
276       if (!operand_inst || !operand_inst->type_id()) {
277         return 0;
278       }
279       auto operand_type =
280           ir_context->get_type_mgr()->GetType(operand_inst->type_id());
281       if (!(operand_type->AsInteger() ||
282             (operand_type->AsVector() &&
283              operand_type->AsVector()->element_type()->AsInteger()))) {
284         return 0;
285       }
286       return operand_inst->type_id();
287     }
288     default:
289       assert(false && "Inappropriate opcode for equation instruction.");
290       return 0;
291   }
292 }
293 
GetFreshIds() const294 std::unordered_set<uint32_t> TransformationEquationInstruction::GetFreshIds()
295     const {
296   return {message_.fresh_id()};
297 }
298 
299 }  // namespace fuzz
300 }  // namespace spvtools
301