1 // Copyright (c) 2020 André Perez Maselco
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
#include "source/fuzz/transformation_replace_linear_algebra_instruction.h"

#include <unordered_set>

#include "source/fuzz/fuzzer_util.h"
#include "source/fuzz/instruction_descriptor.h"
19
20 namespace spvtools {
21 namespace fuzz {
22
23 TransformationReplaceLinearAlgebraInstruction::
TransformationReplaceLinearAlgebraInstruction(protobufs::TransformationReplaceLinearAlgebraInstruction message)24 TransformationReplaceLinearAlgebraInstruction(
25 protobufs::TransformationReplaceLinearAlgebraInstruction message)
26 : message_(std::move(message)) {}
27
28 TransformationReplaceLinearAlgebraInstruction::
TransformationReplaceLinearAlgebraInstruction(const std::vector<uint32_t> & fresh_ids,const protobufs::InstructionDescriptor & instruction_descriptor)29 TransformationReplaceLinearAlgebraInstruction(
30 const std::vector<uint32_t>& fresh_ids,
31 const protobufs::InstructionDescriptor& instruction_descriptor) {
32 for (auto fresh_id : fresh_ids) {
33 message_.add_fresh_ids(fresh_id);
34 }
35 *message_.mutable_instruction_descriptor() = instruction_descriptor;
36 }
37
IsApplicable(opt::IRContext * ir_context,const TransformationContext &) const38 bool TransformationReplaceLinearAlgebraInstruction::IsApplicable(
39 opt::IRContext* ir_context, const TransformationContext& /*unused*/) const {
40 auto instruction =
41 FindInstruction(message_.instruction_descriptor(), ir_context);
42
43 // It must be a linear algebra instruction.
44 if (!spvOpcodeIsLinearAlgebra(instruction->opcode())) {
45 return false;
46 }
47
48 // |message_.fresh_ids.size| must be the exact number of fresh ids needed to
49 // apply the transformation.
50 if (static_cast<uint32_t>(message_.fresh_ids().size()) !=
51 GetRequiredFreshIdCount(ir_context, instruction)) {
52 return false;
53 }
54
55 // All ids in |message_.fresh_ids| must be fresh.
56 for (uint32_t fresh_id : message_.fresh_ids()) {
57 if (!fuzzerutil::IsFreshId(ir_context, fresh_id)) {
58 return false;
59 }
60 }
61
62 return true;
63 }
64
Apply(opt::IRContext * ir_context,TransformationContext *) const65 void TransformationReplaceLinearAlgebraInstruction::Apply(
66 opt::IRContext* ir_context, TransformationContext* /*unused*/) const {
67 auto linear_algebra_instruction =
68 FindInstruction(message_.instruction_descriptor(), ir_context);
69
70 switch (linear_algebra_instruction->opcode()) {
71 case spv::Op::OpTranspose:
72 ReplaceOpTranspose(ir_context, linear_algebra_instruction);
73 break;
74 case spv::Op::OpVectorTimesScalar:
75 ReplaceOpVectorTimesScalar(ir_context, linear_algebra_instruction);
76 break;
77 case spv::Op::OpMatrixTimesScalar:
78 ReplaceOpMatrixTimesScalar(ir_context, linear_algebra_instruction);
79 break;
80 case spv::Op::OpVectorTimesMatrix:
81 ReplaceOpVectorTimesMatrix(ir_context, linear_algebra_instruction);
82 break;
83 case spv::Op::OpMatrixTimesVector:
84 ReplaceOpMatrixTimesVector(ir_context, linear_algebra_instruction);
85 break;
86 case spv::Op::OpMatrixTimesMatrix:
87 ReplaceOpMatrixTimesMatrix(ir_context, linear_algebra_instruction);
88 break;
89 case spv::Op::OpOuterProduct:
90 ReplaceOpOuterProduct(ir_context, linear_algebra_instruction);
91 break;
92 case spv::Op::OpDot:
93 ReplaceOpDot(ir_context, linear_algebra_instruction);
94 break;
95 default:
96 assert(false && "Should be unreachable.");
97 break;
98 }
99
100 ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
101 }
102
103 protobufs::Transformation
ToMessage() const104 TransformationReplaceLinearAlgebraInstruction::ToMessage() const {
105 protobufs::Transformation result;
106 *result.mutable_replace_linear_algebra_instruction() = message_;
107 return result;
108 }
109
GetRequiredFreshIdCount(opt::IRContext * ir_context,opt::Instruction * instruction)110 uint32_t TransformationReplaceLinearAlgebraInstruction::GetRequiredFreshIdCount(
111 opt::IRContext* ir_context, opt::Instruction* instruction) {
112 // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3354):
113 // Right now we only support certain operations.
114 switch (instruction->opcode()) {
115 case spv::Op::OpTranspose: {
116 // For each matrix row, |2 * matrix_column_count| OpCompositeExtract and 1
117 // OpCompositeConstruct will be inserted.
118 auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
119 instruction->GetSingleWordInOperand(0));
120 uint32_t matrix_column_count =
121 ir_context->get_type_mgr()
122 ->GetType(matrix_instruction->type_id())
123 ->AsMatrix()
124 ->element_count();
125 uint32_t matrix_row_count = ir_context->get_type_mgr()
126 ->GetType(matrix_instruction->type_id())
127 ->AsMatrix()
128 ->element_type()
129 ->AsVector()
130 ->element_count();
131 return matrix_row_count * (2 * matrix_column_count + 1);
132 }
133 case spv::Op::OpVectorTimesScalar:
134 // For each vector component, 1 OpCompositeExtract and 1 OpFMul will be
135 // inserted.
136 return 2 *
137 ir_context->get_type_mgr()
138 ->GetType(ir_context->get_def_use_mgr()
139 ->GetDef(instruction->GetSingleWordInOperand(0))
140 ->type_id())
141 ->AsVector()
142 ->element_count();
143 case spv::Op::OpMatrixTimesScalar: {
144 // For each matrix column, |1 + column.size| OpCompositeExtract,
145 // |column.size| OpFMul and 1 OpCompositeConstruct instructions will be
146 // inserted.
147 auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
148 instruction->GetSingleWordInOperand(0));
149 auto matrix_type =
150 ir_context->get_type_mgr()->GetType(matrix_instruction->type_id());
151 return 2 * matrix_type->AsMatrix()->element_count() *
152 (1 + matrix_type->AsMatrix()
153 ->element_type()
154 ->AsVector()
155 ->element_count());
156 }
157 case spv::Op::OpVectorTimesMatrix: {
158 // For each vector component, 1 OpCompositeExtract instruction will be
159 // inserted. For each matrix column, |1 + vector_component_count|
160 // OpCompositeExtract, |vector_component_count| OpFMul and
161 // |vector_component_count - 1| OpFAdd instructions will be inserted.
162 auto vector_instruction = ir_context->get_def_use_mgr()->GetDef(
163 instruction->GetSingleWordInOperand(0));
164 auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
165 instruction->GetSingleWordInOperand(1));
166 uint32_t vector_component_count =
167 ir_context->get_type_mgr()
168 ->GetType(vector_instruction->type_id())
169 ->AsVector()
170 ->element_count();
171 uint32_t matrix_column_count =
172 ir_context->get_type_mgr()
173 ->GetType(matrix_instruction->type_id())
174 ->AsMatrix()
175 ->element_count();
176 return vector_component_count * (3 * matrix_column_count + 1);
177 }
178 case spv::Op::OpMatrixTimesVector: {
179 // For each matrix column, |1 + matrix_row_count| OpCompositeExtract
180 // will be inserted. For each matrix row, |matrix_column_count| OpFMul and
181 // |matrix_column_count - 1| OpFAdd instructions will be inserted. For
182 // each vector component, 1 OpCompositeExtract instruction will be
183 // inserted.
184 auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
185 instruction->GetSingleWordInOperand(0));
186 uint32_t matrix_column_count =
187 ir_context->get_type_mgr()
188 ->GetType(matrix_instruction->type_id())
189 ->AsMatrix()
190 ->element_count();
191 uint32_t matrix_row_count = ir_context->get_type_mgr()
192 ->GetType(matrix_instruction->type_id())
193 ->AsMatrix()
194 ->element_type()
195 ->AsVector()
196 ->element_count();
197 return 3 * matrix_column_count * matrix_row_count +
198 2 * matrix_column_count - matrix_row_count;
199 }
200 case spv::Op::OpMatrixTimesMatrix: {
201 // For each matrix 2 column, 1 OpCompositeExtract, 1 OpCompositeConstruct,
202 // |3 * matrix_1_row_count * matrix_1_column_count| OpCompositeExtract,
203 // |matrix_1_row_count * matrix_1_column_count| OpFMul,
204 // |matrix_1_row_count * (matrix_1_column_count - 1)| OpFAdd instructions
205 // will be inserted.
206 auto matrix_1_instruction = ir_context->get_def_use_mgr()->GetDef(
207 instruction->GetSingleWordInOperand(0));
208 uint32_t matrix_1_column_count =
209 ir_context->get_type_mgr()
210 ->GetType(matrix_1_instruction->type_id())
211 ->AsMatrix()
212 ->element_count();
213 uint32_t matrix_1_row_count =
214 ir_context->get_type_mgr()
215 ->GetType(matrix_1_instruction->type_id())
216 ->AsMatrix()
217 ->element_type()
218 ->AsVector()
219 ->element_count();
220
221 auto matrix_2_instruction = ir_context->get_def_use_mgr()->GetDef(
222 instruction->GetSingleWordInOperand(1));
223 uint32_t matrix_2_column_count =
224 ir_context->get_type_mgr()
225 ->GetType(matrix_2_instruction->type_id())
226 ->AsMatrix()
227 ->element_count();
228 return matrix_2_column_count *
229 (2 + matrix_1_row_count * (5 * matrix_1_column_count - 1));
230 }
231 case spv::Op::OpOuterProduct: {
232 // For each |vector_2| component, |vector_1_component_count + 1|
233 // OpCompositeExtract, |vector_1_component_count| OpFMul and 1
234 // OpCompositeConstruct instructions will be inserted.
235 auto vector_1_instruction = ir_context->get_def_use_mgr()->GetDef(
236 instruction->GetSingleWordInOperand(0));
237 auto vector_2_instruction = ir_context->get_def_use_mgr()->GetDef(
238 instruction->GetSingleWordInOperand(1));
239 uint32_t vector_1_component_count =
240 ir_context->get_type_mgr()
241 ->GetType(vector_1_instruction->type_id())
242 ->AsVector()
243 ->element_count();
244 uint32_t vector_2_component_count =
245 ir_context->get_type_mgr()
246 ->GetType(vector_2_instruction->type_id())
247 ->AsVector()
248 ->element_count();
249 return 2 * vector_2_component_count * (vector_1_component_count + 1);
250 }
251 case spv::Op::OpDot:
252 // For each pair of vector components, 2 OpCompositeExtract and 1 OpFMul
253 // will be inserted. The first two OpFMul instructions will result the
254 // first OpFAdd instruction to be inserted. For each remaining OpFMul, 1
255 // OpFAdd will be inserted. The last OpFAdd instruction is got by changing
256 // the OpDot instruction.
257 return 4 * ir_context->get_type_mgr()
258 ->GetType(
259 ir_context->get_def_use_mgr()
260 ->GetDef(instruction->GetSingleWordInOperand(0))
261 ->type_id())
262 ->AsVector()
263 ->element_count() -
264 2;
265 default:
266 assert(false && "Unsupported linear algebra instruction.");
267 return 0;
268 }
269 }
270
ReplaceOpTranspose(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const271 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpTranspose(
272 opt::IRContext* ir_context,
273 opt::Instruction* linear_algebra_instruction) const {
274 // Gets OpTranspose instruction information.
275 auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
276 linear_algebra_instruction->GetSingleWordInOperand(0));
277 uint32_t matrix_column_count = ir_context->get_type_mgr()
278 ->GetType(matrix_instruction->type_id())
279 ->AsMatrix()
280 ->element_count();
281 auto matrix_column_type = ir_context->get_type_mgr()
282 ->GetType(matrix_instruction->type_id())
283 ->AsMatrix()
284 ->element_type();
285 auto matrix_column_component_type =
286 matrix_column_type->AsVector()->element_type();
287 uint32_t matrix_row_count = matrix_column_type->AsVector()->element_count();
288 auto resulting_matrix_column_type =
289 ir_context->get_type_mgr()
290 ->GetType(linear_algebra_instruction->type_id())
291 ->AsMatrix()
292 ->element_type();
293
294 uint32_t fresh_id_index = 0;
295 std::vector<uint32_t> result_column_ids(matrix_row_count);
296 for (uint32_t i = 0; i < matrix_row_count; i++) {
297 std::vector<uint32_t> column_component_ids(matrix_column_count);
298 for (uint32_t j = 0; j < matrix_column_count; j++) {
299 // Extracts the matrix column.
300 uint32_t matrix_column_id = message_.fresh_ids(fresh_id_index++);
301 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
302 ir_context, spv::Op::OpCompositeExtract,
303 ir_context->get_type_mgr()->GetId(matrix_column_type),
304 matrix_column_id,
305 opt::Instruction::OperandList(
306 {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}},
307 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
308
309 // Extracts the matrix column component.
310 column_component_ids[j] = message_.fresh_ids(fresh_id_index++);
311 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
312 ir_context, spv::Op::OpCompositeExtract,
313 ir_context->get_type_mgr()->GetId(matrix_column_component_type),
314 column_component_ids[j],
315 opt::Instruction::OperandList(
316 {{SPV_OPERAND_TYPE_ID, {matrix_column_id}},
317 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
318 }
319
320 // Inserts the resulting matrix column.
321 opt::Instruction::OperandList in_operands;
322 for (auto& column_component_id : column_component_ids) {
323 in_operands.push_back({SPV_OPERAND_TYPE_ID, {column_component_id}});
324 }
325 result_column_ids[i] = message_.fresh_ids(fresh_id_index++);
326 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
327 ir_context, spv::Op::OpCompositeConstruct,
328 ir_context->get_type_mgr()->GetId(resulting_matrix_column_type),
329 result_column_ids[i], opt::Instruction::OperandList(in_operands)));
330 }
331
332 // The OpTranspose instruction is changed to an OpCompositeConstruct
333 // instruction.
334 linear_algebra_instruction->SetOpcode(spv::Op::OpCompositeConstruct);
335 linear_algebra_instruction->SetInOperand(0, {result_column_ids[0]});
336 for (uint32_t i = 1; i < result_column_ids.size(); i++) {
337 linear_algebra_instruction->AddOperand(
338 {SPV_OPERAND_TYPE_ID, {result_column_ids[i]}});
339 }
340
341 fuzzerutil::UpdateModuleIdBound(
342 ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
343 }
344
ReplaceOpVectorTimesScalar(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const345 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpVectorTimesScalar(
346 opt::IRContext* ir_context,
347 opt::Instruction* linear_algebra_instruction) const {
348 // Gets OpVectorTimesScalar in operands.
349 auto vector = ir_context->get_def_use_mgr()->GetDef(
350 linear_algebra_instruction->GetSingleWordInOperand(0));
351 auto scalar = ir_context->get_def_use_mgr()->GetDef(
352 linear_algebra_instruction->GetSingleWordInOperand(1));
353
354 uint32_t vector_component_count = ir_context->get_type_mgr()
355 ->GetType(vector->type_id())
356 ->AsVector()
357 ->element_count();
358 std::vector<uint32_t> float_multiplication_ids(vector_component_count);
359 uint32_t fresh_id_index = 0;
360
361 for (uint32_t i = 0; i < vector_component_count; i++) {
362 // Extracts |vector| component.
363 uint32_t vector_extract_id = message_.fresh_ids(fresh_id_index++);
364 fuzzerutil::UpdateModuleIdBound(ir_context, vector_extract_id);
365 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
366 ir_context, spv::Op::OpCompositeExtract, scalar->type_id(),
367 vector_extract_id,
368 opt::Instruction::OperandList(
369 {{SPV_OPERAND_TYPE_ID, {vector->result_id()}},
370 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
371
372 // Multiplies the |vector| component with the |scalar|.
373 uint32_t float_multiplication_id = message_.fresh_ids(fresh_id_index++);
374 float_multiplication_ids[i] = float_multiplication_id;
375 fuzzerutil::UpdateModuleIdBound(ir_context, float_multiplication_id);
376 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
377 ir_context, spv::Op::OpFMul, scalar->type_id(), float_multiplication_id,
378 opt::Instruction::OperandList(
379 {{SPV_OPERAND_TYPE_ID, {vector_extract_id}},
380 {SPV_OPERAND_TYPE_ID, {scalar->result_id()}}})));
381 }
382
383 // The OpVectorTimesScalar instruction is changed to an OpCompositeConstruct
384 // instruction.
385 linear_algebra_instruction->SetOpcode(spv::Op::OpCompositeConstruct);
386 linear_algebra_instruction->SetInOperand(0, {float_multiplication_ids[0]});
387 linear_algebra_instruction->SetInOperand(1, {float_multiplication_ids[1]});
388 for (uint32_t i = 2; i < float_multiplication_ids.size(); i++) {
389 linear_algebra_instruction->AddOperand(
390 {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[i]}});
391 }
392 }
393
ReplaceOpMatrixTimesScalar(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const394 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpMatrixTimesScalar(
395 opt::IRContext* ir_context,
396 opt::Instruction* linear_algebra_instruction) const {
397 // Gets OpMatrixTimesScalar in operands.
398 auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
399 linear_algebra_instruction->GetSingleWordInOperand(0));
400 auto scalar_instruction = ir_context->get_def_use_mgr()->GetDef(
401 linear_algebra_instruction->GetSingleWordInOperand(1));
402
403 // Gets matrix information.
404 uint32_t matrix_column_count = ir_context->get_type_mgr()
405 ->GetType(matrix_instruction->type_id())
406 ->AsMatrix()
407 ->element_count();
408 auto matrix_column_type = ir_context->get_type_mgr()
409 ->GetType(matrix_instruction->type_id())
410 ->AsMatrix()
411 ->element_type();
412 uint32_t matrix_column_size = matrix_column_type->AsVector()->element_count();
413
414 std::vector<uint32_t> composite_construct_ids(matrix_column_count);
415 uint32_t fresh_id_index = 0;
416
417 for (uint32_t i = 0; i < matrix_column_count; i++) {
418 // Extracts |matrix| column.
419 uint32_t matrix_extract_id = message_.fresh_ids(fresh_id_index++);
420 fuzzerutil::UpdateModuleIdBound(ir_context, matrix_extract_id);
421 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
422 ir_context, spv::Op::OpCompositeExtract,
423 ir_context->get_type_mgr()->GetId(matrix_column_type),
424 matrix_extract_id,
425 opt::Instruction::OperandList(
426 {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}},
427 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
428
429 std::vector<uint32_t> float_multiplication_ids(matrix_column_size);
430
431 for (uint32_t j = 0; j < matrix_column_size; j++) {
432 // Extracts |column| component.
433 uint32_t column_extract_id = message_.fresh_ids(fresh_id_index++);
434 fuzzerutil::UpdateModuleIdBound(ir_context, column_extract_id);
435 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
436 ir_context, spv::Op::OpCompositeExtract,
437 scalar_instruction->type_id(), column_extract_id,
438 opt::Instruction::OperandList(
439 {{SPV_OPERAND_TYPE_ID, {matrix_extract_id}},
440 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
441
442 // Multiplies the |column| component with the |scalar|.
443 float_multiplication_ids[j] = message_.fresh_ids(fresh_id_index++);
444 fuzzerutil::UpdateModuleIdBound(ir_context, float_multiplication_ids[j]);
445 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
446 ir_context, spv::Op::OpFMul, scalar_instruction->type_id(),
447 float_multiplication_ids[j],
448 opt::Instruction::OperandList(
449 {{SPV_OPERAND_TYPE_ID, {column_extract_id}},
450 {SPV_OPERAND_TYPE_ID, {scalar_instruction->result_id()}}})));
451 }
452
453 // Constructs a new column multiplied by |scalar|.
454 opt::Instruction::OperandList composite_construct_in_operands;
455 for (uint32_t& float_multiplication_id : float_multiplication_ids) {
456 composite_construct_in_operands.push_back(
457 {SPV_OPERAND_TYPE_ID, {float_multiplication_id}});
458 }
459 composite_construct_ids[i] = message_.fresh_ids(fresh_id_index++);
460 fuzzerutil::UpdateModuleIdBound(ir_context, composite_construct_ids[i]);
461 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
462 ir_context, spv::Op::OpCompositeConstruct,
463 ir_context->get_type_mgr()->GetId(matrix_column_type),
464 composite_construct_ids[i], composite_construct_in_operands));
465 }
466
467 // The OpMatrixTimesScalar instruction is changed to an OpCompositeConstruct
468 // instruction.
469 linear_algebra_instruction->SetOpcode(spv::Op::OpCompositeConstruct);
470 linear_algebra_instruction->SetInOperand(0, {composite_construct_ids[0]});
471 linear_algebra_instruction->SetInOperand(1, {composite_construct_ids[1]});
472 for (uint32_t i = 2; i < composite_construct_ids.size(); i++) {
473 linear_algebra_instruction->AddOperand(
474 {SPV_OPERAND_TYPE_ID, {composite_construct_ids[i]}});
475 }
476 }
477
ReplaceOpVectorTimesMatrix(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const478 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpVectorTimesMatrix(
479 opt::IRContext* ir_context,
480 opt::Instruction* linear_algebra_instruction) const {
481 // Gets vector information.
482 auto vector_instruction = ir_context->get_def_use_mgr()->GetDef(
483 linear_algebra_instruction->GetSingleWordInOperand(0));
484 uint32_t vector_component_count = ir_context->get_type_mgr()
485 ->GetType(vector_instruction->type_id())
486 ->AsVector()
487 ->element_count();
488 auto vector_component_type = ir_context->get_type_mgr()
489 ->GetType(vector_instruction->type_id())
490 ->AsVector()
491 ->element_type();
492
493 // Extracts vector components.
494 uint32_t fresh_id_index = 0;
495 std::vector<uint32_t> vector_component_ids(vector_component_count);
496 for (uint32_t i = 0; i < vector_component_count; i++) {
497 vector_component_ids[i] = message_.fresh_ids(fresh_id_index++);
498 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
499 ir_context, spv::Op::OpCompositeExtract,
500 ir_context->get_type_mgr()->GetId(vector_component_type),
501 vector_component_ids[i],
502 opt::Instruction::OperandList(
503 {{SPV_OPERAND_TYPE_ID, {vector_instruction->result_id()}},
504 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
505 }
506
507 // Gets matrix information.
508 auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
509 linear_algebra_instruction->GetSingleWordInOperand(1));
510 uint32_t matrix_column_count = ir_context->get_type_mgr()
511 ->GetType(matrix_instruction->type_id())
512 ->AsMatrix()
513 ->element_count();
514 auto matrix_column_type = ir_context->get_type_mgr()
515 ->GetType(matrix_instruction->type_id())
516 ->AsMatrix()
517 ->element_type();
518
519 std::vector<uint32_t> result_component_ids(matrix_column_count);
520 for (uint32_t i = 0; i < matrix_column_count; i++) {
521 // Extracts matrix column.
522 uint32_t matrix_extract_id = message_.fresh_ids(fresh_id_index++);
523 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
524 ir_context, spv::Op::OpCompositeExtract,
525 ir_context->get_type_mgr()->GetId(matrix_column_type),
526 matrix_extract_id,
527 opt::Instruction::OperandList(
528 {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}},
529 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
530
531 std::vector<uint32_t> float_multiplication_ids(vector_component_count);
532 for (uint32_t j = 0; j < vector_component_count; j++) {
533 // Extracts column component.
534 uint32_t column_extract_id = message_.fresh_ids(fresh_id_index++);
535 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
536 ir_context, spv::Op::OpCompositeExtract,
537 ir_context->get_type_mgr()->GetId(vector_component_type),
538 column_extract_id,
539 opt::Instruction::OperandList(
540 {{SPV_OPERAND_TYPE_ID, {matrix_extract_id}},
541 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
542
543 // Multiplies corresponding vector and column components.
544 float_multiplication_ids[j] = message_.fresh_ids(fresh_id_index++);
545 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
546 ir_context, spv::Op::OpFMul,
547 ir_context->get_type_mgr()->GetId(vector_component_type),
548 float_multiplication_ids[j],
549 opt::Instruction::OperandList(
550 {{SPV_OPERAND_TYPE_ID, {vector_component_ids[j]}},
551 {SPV_OPERAND_TYPE_ID, {column_extract_id}}})));
552 }
553
554 // Adds the multiplication results.
555 std::vector<uint32_t> float_add_ids;
556 uint32_t float_add_id = message_.fresh_ids(fresh_id_index++);
557 float_add_ids.push_back(float_add_id);
558 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
559 ir_context, spv::Op::OpFAdd,
560 ir_context->get_type_mgr()->GetId(vector_component_type), float_add_id,
561 opt::Instruction::OperandList(
562 {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}},
563 {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}})));
564 for (uint32_t j = 2; j < float_multiplication_ids.size(); j++) {
565 float_add_id = message_.fresh_ids(fresh_id_index++);
566 float_add_ids.push_back(float_add_id);
567 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
568 ir_context, spv::Op::OpFAdd,
569 ir_context->get_type_mgr()->GetId(vector_component_type),
570 float_add_id,
571 opt::Instruction::OperandList(
572 {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[j]}},
573 {SPV_OPERAND_TYPE_ID, {float_add_ids[j - 2]}}})));
574 }
575
576 result_component_ids[i] = float_add_ids.back();
577 }
578
579 // The OpVectorTimesMatrix instruction is changed to an OpCompositeConstruct
580 // instruction.
581 linear_algebra_instruction->SetOpcode(spv::Op::OpCompositeConstruct);
582 linear_algebra_instruction->SetInOperand(0, {result_component_ids[0]});
583 linear_algebra_instruction->SetInOperand(1, {result_component_ids[1]});
584 for (uint32_t i = 2; i < result_component_ids.size(); i++) {
585 linear_algebra_instruction->AddOperand(
586 {SPV_OPERAND_TYPE_ID, {result_component_ids[i]}});
587 }
588
589 fuzzerutil::UpdateModuleIdBound(
590 ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
591 }
592
ReplaceOpMatrixTimesVector(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const593 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpMatrixTimesVector(
594 opt::IRContext* ir_context,
595 opt::Instruction* linear_algebra_instruction) const {
596 // Gets matrix information.
597 auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
598 linear_algebra_instruction->GetSingleWordInOperand(0));
599 uint32_t matrix_column_count = ir_context->get_type_mgr()
600 ->GetType(matrix_instruction->type_id())
601 ->AsMatrix()
602 ->element_count();
603 auto matrix_column_type = ir_context->get_type_mgr()
604 ->GetType(matrix_instruction->type_id())
605 ->AsMatrix()
606 ->element_type();
607 uint32_t matrix_row_count = matrix_column_type->AsVector()->element_count();
608
609 // Extracts matrix columns.
610 uint32_t fresh_id_index = 0;
611 std::vector<uint32_t> matrix_column_ids(matrix_column_count);
612 for (uint32_t i = 0; i < matrix_column_count; i++) {
613 matrix_column_ids[i] = message_.fresh_ids(fresh_id_index++);
614 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
615 ir_context, spv::Op::OpCompositeExtract,
616 ir_context->get_type_mgr()->GetId(matrix_column_type),
617 matrix_column_ids[i],
618 opt::Instruction::OperandList(
619 {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}},
620 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
621 }
622
623 // Gets vector information.
624 auto vector_instruction = ir_context->get_def_use_mgr()->GetDef(
625 linear_algebra_instruction->GetSingleWordInOperand(1));
626 auto vector_component_type = ir_context->get_type_mgr()
627 ->GetType(vector_instruction->type_id())
628 ->AsVector()
629 ->element_type();
630
631 // Extracts vector components.
632 std::vector<uint32_t> vector_component_ids(matrix_column_count);
633 for (uint32_t i = 0; i < matrix_column_count; i++) {
634 vector_component_ids[i] = message_.fresh_ids(fresh_id_index++);
635 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
636 ir_context, spv::Op::OpCompositeExtract,
637 ir_context->get_type_mgr()->GetId(vector_component_type),
638 vector_component_ids[i],
639 opt::Instruction::OperandList(
640 {{SPV_OPERAND_TYPE_ID, {vector_instruction->result_id()}},
641 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
642 }
643
644 std::vector<uint32_t> result_component_ids(matrix_row_count);
645 for (uint32_t i = 0; i < matrix_row_count; i++) {
646 std::vector<uint32_t> float_multiplication_ids(matrix_column_count);
647 for (uint32_t j = 0; j < matrix_column_count; j++) {
648 // Extracts column component.
649 uint32_t column_extract_id = message_.fresh_ids(fresh_id_index++);
650 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
651 ir_context, spv::Op::OpCompositeExtract,
652 ir_context->get_type_mgr()->GetId(vector_component_type),
653 column_extract_id,
654 opt::Instruction::OperandList(
655 {{SPV_OPERAND_TYPE_ID, {matrix_column_ids[j]}},
656 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
657
658 // Multiplies corresponding vector and column components.
659 float_multiplication_ids[j] = message_.fresh_ids(fresh_id_index++);
660 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
661 ir_context, spv::Op::OpFMul,
662 ir_context->get_type_mgr()->GetId(vector_component_type),
663 float_multiplication_ids[j],
664 opt::Instruction::OperandList(
665 {{SPV_OPERAND_TYPE_ID, {column_extract_id}},
666 {SPV_OPERAND_TYPE_ID, {vector_component_ids[j]}}})));
667 }
668
669 // Adds the multiplication results.
670 std::vector<uint32_t> float_add_ids;
671 uint32_t float_add_id = message_.fresh_ids(fresh_id_index++);
672 float_add_ids.push_back(float_add_id);
673 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
674 ir_context, spv::Op::OpFAdd,
675 ir_context->get_type_mgr()->GetId(vector_component_type), float_add_id,
676 opt::Instruction::OperandList(
677 {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}},
678 {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}})));
679 for (uint32_t j = 2; j < float_multiplication_ids.size(); j++) {
680 float_add_id = message_.fresh_ids(fresh_id_index++);
681 float_add_ids.push_back(float_add_id);
682 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
683 ir_context, spv::Op::OpFAdd,
684 ir_context->get_type_mgr()->GetId(vector_component_type),
685 float_add_id,
686 opt::Instruction::OperandList(
687 {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[j]}},
688 {SPV_OPERAND_TYPE_ID, {float_add_ids[j - 2]}}})));
689 }
690
691 result_component_ids[i] = float_add_ids.back();
692 }
693
694 // The OpMatrixTimesVector instruction is changed to an OpCompositeConstruct
695 // instruction.
696 linear_algebra_instruction->SetOpcode(spv::Op::OpCompositeConstruct);
697 linear_algebra_instruction->SetInOperand(0, {result_component_ids[0]});
698 linear_algebra_instruction->SetInOperand(1, {result_component_ids[1]});
699 for (uint32_t i = 2; i < result_component_ids.size(); i++) {
700 linear_algebra_instruction->AddOperand(
701 {SPV_OPERAND_TYPE_ID, {result_component_ids[i]}});
702 }
703
704 fuzzerutil::UpdateModuleIdBound(
705 ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
706 }
707
ReplaceOpMatrixTimesMatrix(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const708 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpMatrixTimesMatrix(
709 opt::IRContext* ir_context,
710 opt::Instruction* linear_algebra_instruction) const {
711 // Gets matrix 1 information.
712 auto matrix_1_instruction = ir_context->get_def_use_mgr()->GetDef(
713 linear_algebra_instruction->GetSingleWordInOperand(0));
714 uint32_t matrix_1_column_count =
715 ir_context->get_type_mgr()
716 ->GetType(matrix_1_instruction->type_id())
717 ->AsMatrix()
718 ->element_count();
719 auto matrix_1_column_type = ir_context->get_type_mgr()
720 ->GetType(matrix_1_instruction->type_id())
721 ->AsMatrix()
722 ->element_type();
723 auto matrix_1_column_component_type =
724 matrix_1_column_type->AsVector()->element_type();
725 uint32_t matrix_1_row_count =
726 matrix_1_column_type->AsVector()->element_count();
727
728 // Gets matrix 2 information.
729 auto matrix_2_instruction = ir_context->get_def_use_mgr()->GetDef(
730 linear_algebra_instruction->GetSingleWordInOperand(1));
731 uint32_t matrix_2_column_count =
732 ir_context->get_type_mgr()
733 ->GetType(matrix_2_instruction->type_id())
734 ->AsMatrix()
735 ->element_count();
736 auto matrix_2_column_type = ir_context->get_type_mgr()
737 ->GetType(matrix_2_instruction->type_id())
738 ->AsMatrix()
739 ->element_type();
740
741 uint32_t fresh_id_index = 0;
742 std::vector<uint32_t> result_column_ids(matrix_2_column_count);
743 for (uint32_t i = 0; i < matrix_2_column_count; i++) {
744 // Extracts matrix 2 column.
745 uint32_t matrix_2_column_id = message_.fresh_ids(fresh_id_index++);
746 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
747 ir_context, spv::Op::OpCompositeExtract,
748 ir_context->get_type_mgr()->GetId(matrix_2_column_type),
749 matrix_2_column_id,
750 opt::Instruction::OperandList(
751 {{SPV_OPERAND_TYPE_ID, {matrix_2_instruction->result_id()}},
752 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
753
754 std::vector<uint32_t> column_component_ids(matrix_1_row_count);
755 for (uint32_t j = 0; j < matrix_1_row_count; j++) {
756 std::vector<uint32_t> float_multiplication_ids(matrix_1_column_count);
757 for (uint32_t k = 0; k < matrix_1_column_count; k++) {
758 // Extracts matrix 1 column.
759 uint32_t matrix_1_column_id = message_.fresh_ids(fresh_id_index++);
760 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
761 ir_context, spv::Op::OpCompositeExtract,
762 ir_context->get_type_mgr()->GetId(matrix_1_column_type),
763 matrix_1_column_id,
764 opt::Instruction::OperandList(
765 {{SPV_OPERAND_TYPE_ID, {matrix_1_instruction->result_id()}},
766 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {k}}})));
767
768 // Extracts matrix 1 column component.
769 uint32_t matrix_1_column_component_id =
770 message_.fresh_ids(fresh_id_index++);
771 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
772 ir_context, spv::Op::OpCompositeExtract,
773 ir_context->get_type_mgr()->GetId(matrix_1_column_component_type),
774 matrix_1_column_component_id,
775 opt::Instruction::OperandList(
776 {{SPV_OPERAND_TYPE_ID, {matrix_1_column_id}},
777 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
778
779 // Extracts matrix 2 column component.
780 uint32_t matrix_2_column_component_id =
781 message_.fresh_ids(fresh_id_index++);
782 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
783 ir_context, spv::Op::OpCompositeExtract,
784 ir_context->get_type_mgr()->GetId(matrix_1_column_component_type),
785 matrix_2_column_component_id,
786 opt::Instruction::OperandList(
787 {{SPV_OPERAND_TYPE_ID, {matrix_2_column_id}},
788 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {k}}})));
789
790 // Multiplies corresponding matrix 1 and matrix 2 column components.
791 float_multiplication_ids[k] = message_.fresh_ids(fresh_id_index++);
792 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
793 ir_context, spv::Op::OpFMul,
794 ir_context->get_type_mgr()->GetId(matrix_1_column_component_type),
795 float_multiplication_ids[k],
796 opt::Instruction::OperandList(
797 {{SPV_OPERAND_TYPE_ID, {matrix_1_column_component_id}},
798 {SPV_OPERAND_TYPE_ID, {matrix_2_column_component_id}}})));
799 }
800
801 // Adds the multiplication results.
802 std::vector<uint32_t> float_add_ids;
803 uint32_t float_add_id = message_.fresh_ids(fresh_id_index++);
804 float_add_ids.push_back(float_add_id);
805 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
806 ir_context, spv::Op::OpFAdd,
807 ir_context->get_type_mgr()->GetId(matrix_1_column_component_type),
808 float_add_id,
809 opt::Instruction::OperandList(
810 {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}},
811 {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}})));
812 for (uint32_t k = 2; k < float_multiplication_ids.size(); k++) {
813 float_add_id = message_.fresh_ids(fresh_id_index++);
814 float_add_ids.push_back(float_add_id);
815 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
816 ir_context, spv::Op::OpFAdd,
817 ir_context->get_type_mgr()->GetId(matrix_1_column_component_type),
818 float_add_id,
819 opt::Instruction::OperandList(
820 {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[k]}},
821 {SPV_OPERAND_TYPE_ID, {float_add_ids[k - 2]}}})));
822 }
823
824 column_component_ids[j] = float_add_ids.back();
825 }
826
827 // Inserts the resulting matrix column.
828 opt::Instruction::OperandList in_operands;
829 for (auto& column_component_id : column_component_ids) {
830 in_operands.push_back({SPV_OPERAND_TYPE_ID, {column_component_id}});
831 }
832 result_column_ids[i] = message_.fresh_ids(fresh_id_index++);
833 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
834 ir_context, spv::Op::OpCompositeConstruct,
835 ir_context->get_type_mgr()->GetId(matrix_1_column_type),
836 result_column_ids[i], opt::Instruction::OperandList(in_operands)));
837 }
838
839 // The OpMatrixTimesMatrix instruction is changed to an OpCompositeConstruct
840 // instruction.
841 linear_algebra_instruction->SetOpcode(spv::Op::OpCompositeConstruct);
842 linear_algebra_instruction->SetInOperand(0, {result_column_ids[0]});
843 linear_algebra_instruction->SetInOperand(1, {result_column_ids[1]});
844 for (uint32_t i = 2; i < result_column_ids.size(); i++) {
845 linear_algebra_instruction->AddOperand(
846 {SPV_OPERAND_TYPE_ID, {result_column_ids[i]}});
847 }
848
849 fuzzerutil::UpdateModuleIdBound(
850 ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
851 }
852
ReplaceOpOuterProduct(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const853 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpOuterProduct(
854 opt::IRContext* ir_context,
855 opt::Instruction* linear_algebra_instruction) const {
856 // Gets vector 1 information.
857 auto vector_1_instruction = ir_context->get_def_use_mgr()->GetDef(
858 linear_algebra_instruction->GetSingleWordInOperand(0));
859 uint32_t vector_1_component_count =
860 ir_context->get_type_mgr()
861 ->GetType(vector_1_instruction->type_id())
862 ->AsVector()
863 ->element_count();
864 auto vector_1_component_type = ir_context->get_type_mgr()
865 ->GetType(vector_1_instruction->type_id())
866 ->AsVector()
867 ->element_type();
868
869 // Gets vector 2 information.
870 auto vector_2_instruction = ir_context->get_def_use_mgr()->GetDef(
871 linear_algebra_instruction->GetSingleWordInOperand(1));
872 uint32_t vector_2_component_count =
873 ir_context->get_type_mgr()
874 ->GetType(vector_2_instruction->type_id())
875 ->AsVector()
876 ->element_count();
877
878 uint32_t fresh_id_index = 0;
879 std::vector<uint32_t> result_column_ids(vector_2_component_count);
880 for (uint32_t i = 0; i < vector_2_component_count; i++) {
881 // Extracts |vector_2| component.
882 uint32_t vector_2_component_id = message_.fresh_ids(fresh_id_index++);
883 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
884 ir_context, spv::Op::OpCompositeExtract,
885 ir_context->get_type_mgr()->GetId(vector_1_component_type),
886 vector_2_component_id,
887 opt::Instruction::OperandList(
888 {{SPV_OPERAND_TYPE_ID, {vector_2_instruction->result_id()}},
889 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
890
891 std::vector<uint32_t> column_component_ids(vector_1_component_count);
892 for (uint32_t j = 0; j < vector_1_component_count; j++) {
893 // Extracts |vector_1| component.
894 uint32_t vector_1_component_id = message_.fresh_ids(fresh_id_index++);
895 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
896 ir_context, spv::Op::OpCompositeExtract,
897 ir_context->get_type_mgr()->GetId(vector_1_component_type),
898 vector_1_component_id,
899 opt::Instruction::OperandList(
900 {{SPV_OPERAND_TYPE_ID, {vector_1_instruction->result_id()}},
901 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
902
903 // Multiplies |vector_1| and |vector_2| components.
904 column_component_ids[j] = message_.fresh_ids(fresh_id_index++);
905 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
906 ir_context, spv::Op::OpFMul,
907 ir_context->get_type_mgr()->GetId(vector_1_component_type),
908 column_component_ids[j],
909 opt::Instruction::OperandList(
910 {{SPV_OPERAND_TYPE_ID, {vector_2_component_id}},
911 {SPV_OPERAND_TYPE_ID, {vector_1_component_id}}})));
912 }
913
914 // Inserts the resulting matrix column.
915 opt::Instruction::OperandList in_operands;
916 for (auto& column_component_id : column_component_ids) {
917 in_operands.push_back({SPV_OPERAND_TYPE_ID, {column_component_id}});
918 }
919 result_column_ids[i] = message_.fresh_ids(fresh_id_index++);
920 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
921 ir_context, spv::Op::OpCompositeConstruct,
922 vector_1_instruction->type_id(), result_column_ids[i], in_operands));
923 }
924
925 // The OpOuterProduct instruction is changed to an OpCompositeConstruct
926 // instruction.
927 linear_algebra_instruction->SetOpcode(spv::Op::OpCompositeConstruct);
928 linear_algebra_instruction->SetInOperand(0, {result_column_ids[0]});
929 linear_algebra_instruction->SetInOperand(1, {result_column_ids[1]});
930 for (uint32_t i = 2; i < result_column_ids.size(); i++) {
931 linear_algebra_instruction->AddOperand(
932 {SPV_OPERAND_TYPE_ID, {result_column_ids[i]}});
933 }
934
935 fuzzerutil::UpdateModuleIdBound(
936 ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
937 }
938
ReplaceOpDot(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const939 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpDot(
940 opt::IRContext* ir_context,
941 opt::Instruction* linear_algebra_instruction) const {
942 // Gets OpDot in operands.
943 auto vector_1 = ir_context->get_def_use_mgr()->GetDef(
944 linear_algebra_instruction->GetSingleWordInOperand(0));
945 auto vector_2 = ir_context->get_def_use_mgr()->GetDef(
946 linear_algebra_instruction->GetSingleWordInOperand(1));
947
948 uint32_t vectors_component_count = ir_context->get_type_mgr()
949 ->GetType(vector_1->type_id())
950 ->AsVector()
951 ->element_count();
952 std::vector<uint32_t> float_multiplication_ids(vectors_component_count);
953 uint32_t fresh_id_index = 0;
954
955 for (uint32_t i = 0; i < vectors_component_count; i++) {
956 // Extracts |vector_1| component.
957 uint32_t vector_1_extract_id = message_.fresh_ids(fresh_id_index++);
958 fuzzerutil::UpdateModuleIdBound(ir_context, vector_1_extract_id);
959 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
960 ir_context, spv::Op::OpCompositeExtract,
961 linear_algebra_instruction->type_id(), vector_1_extract_id,
962 opt::Instruction::OperandList(
963 {{SPV_OPERAND_TYPE_ID, {vector_1->result_id()}},
964 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
965
966 // Extracts |vector_2| component.
967 uint32_t vector_2_extract_id = message_.fresh_ids(fresh_id_index++);
968 fuzzerutil::UpdateModuleIdBound(ir_context, vector_2_extract_id);
969 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
970 ir_context, spv::Op::OpCompositeExtract,
971 linear_algebra_instruction->type_id(), vector_2_extract_id,
972 opt::Instruction::OperandList(
973 {{SPV_OPERAND_TYPE_ID, {vector_2->result_id()}},
974 {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
975
976 // Multiplies the pair of components.
977 float_multiplication_ids[i] = message_.fresh_ids(fresh_id_index++);
978 fuzzerutil::UpdateModuleIdBound(ir_context, float_multiplication_ids[i]);
979 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
980 ir_context, spv::Op::OpFMul, linear_algebra_instruction->type_id(),
981 float_multiplication_ids[i],
982 opt::Instruction::OperandList(
983 {{SPV_OPERAND_TYPE_ID, {vector_1_extract_id}},
984 {SPV_OPERAND_TYPE_ID, {vector_2_extract_id}}})));
985 }
986
987 // If the vector has 2 components, then there will be 2 float multiplication
988 // instructions.
989 if (vectors_component_count == 2) {
990 linear_algebra_instruction->SetOpcode(spv::Op::OpFAdd);
991 linear_algebra_instruction->SetInOperand(0, {float_multiplication_ids[0]});
992 linear_algebra_instruction->SetInOperand(1, {float_multiplication_ids[1]});
993 } else {
994 // The first OpFAdd instruction has as operands the first two OpFMul
995 // instructions.
996 std::vector<uint32_t> float_add_ids;
997 uint32_t float_add_id = message_.fresh_ids(fresh_id_index++);
998 float_add_ids.push_back(float_add_id);
999 fuzzerutil::UpdateModuleIdBound(ir_context, float_add_id);
1000 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
1001 ir_context, spv::Op::OpFAdd, linear_algebra_instruction->type_id(),
1002 float_add_id,
1003 opt::Instruction::OperandList(
1004 {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}},
1005 {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}})));
1006
1007 // The remaining OpFAdd instructions has as operands an OpFMul and an OpFAdd
1008 // instruction.
1009 for (uint32_t i = 2; i < float_multiplication_ids.size() - 1; i++) {
1010 float_add_id = message_.fresh_ids(fresh_id_index++);
1011 fuzzerutil::UpdateModuleIdBound(ir_context, float_add_id);
1012 float_add_ids.push_back(float_add_id);
1013 linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
1014 ir_context, spv::Op::OpFAdd, linear_algebra_instruction->type_id(),
1015 float_add_id,
1016 opt::Instruction::OperandList(
1017 {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[i]}},
1018 {SPV_OPERAND_TYPE_ID, {float_add_ids[i - 2]}}})));
1019 }
1020
1021 // The last OpFAdd instruction is got by changing some of the OpDot
1022 // instruction attributes.
1023 linear_algebra_instruction->SetOpcode(spv::Op::OpFAdd);
1024 linear_algebra_instruction->SetInOperand(
1025 0, {float_multiplication_ids[float_multiplication_ids.size() - 1]});
1026 linear_algebra_instruction->SetInOperand(
1027 1, {float_add_ids[float_add_ids.size() - 1]});
1028 }
1029 }
1030
1031 std::unordered_set<uint32_t>
GetFreshIds() const1032 TransformationReplaceLinearAlgebraInstruction::GetFreshIds() const {
1033 std::unordered_set<uint32_t> result;
1034 for (auto id : message_.fresh_ids()) {
1035 result.insert(id);
1036 }
1037 return result;
1038 }
1039
1040 } // namespace fuzz
1041 } // namespace spvtools
1042