• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2020 André Perez Maselco
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/fuzz/transformation_replace_linear_algebra_instruction.h"
16 
17 #include "source/fuzz/fuzzer_util.h"
18 #include "source/fuzz/instruction_descriptor.h"
19 
20 namespace spvtools {
21 namespace fuzz {
22 
23 TransformationReplaceLinearAlgebraInstruction::
TransformationReplaceLinearAlgebraInstruction(protobufs::TransformationReplaceLinearAlgebraInstruction message)24     TransformationReplaceLinearAlgebraInstruction(
25         protobufs::TransformationReplaceLinearAlgebraInstruction message)
26     : message_(std::move(message)) {}
27 
28 TransformationReplaceLinearAlgebraInstruction::
TransformationReplaceLinearAlgebraInstruction(const std::vector<uint32_t> & fresh_ids,const protobufs::InstructionDescriptor & instruction_descriptor)29     TransformationReplaceLinearAlgebraInstruction(
30         const std::vector<uint32_t>& fresh_ids,
31         const protobufs::InstructionDescriptor& instruction_descriptor) {
32   for (auto fresh_id : fresh_ids) {
33     message_.add_fresh_ids(fresh_id);
34   }
35   *message_.mutable_instruction_descriptor() = instruction_descriptor;
36 }
37 
IsApplicable(opt::IRContext * ir_context,const TransformationContext &) const38 bool TransformationReplaceLinearAlgebraInstruction::IsApplicable(
39     opt::IRContext* ir_context, const TransformationContext& /*unused*/) const {
40   auto instruction =
41       FindInstruction(message_.instruction_descriptor(), ir_context);
42 
43   // It must be a linear algebra instruction.
44   if (!spvOpcodeIsLinearAlgebra(instruction->opcode())) {
45     return false;
46   }
47 
48   // |message_.fresh_ids.size| must be the exact number of fresh ids needed to
49   // apply the transformation.
50   if (static_cast<uint32_t>(message_.fresh_ids().size()) !=
51       GetRequiredFreshIdCount(ir_context, instruction)) {
52     return false;
53   }
54 
55   // All ids in |message_.fresh_ids| must be fresh.
56   for (uint32_t fresh_id : message_.fresh_ids()) {
57     if (!fuzzerutil::IsFreshId(ir_context, fresh_id)) {
58       return false;
59     }
60   }
61 
62   return true;
63 }
64 
Apply(opt::IRContext * ir_context,TransformationContext *) const65 void TransformationReplaceLinearAlgebraInstruction::Apply(
66     opt::IRContext* ir_context, TransformationContext* /*unused*/) const {
67   auto linear_algebra_instruction =
68       FindInstruction(message_.instruction_descriptor(), ir_context);
69 
70   switch (linear_algebra_instruction->opcode()) {
71     case spv::Op::OpTranspose:
72       ReplaceOpTranspose(ir_context, linear_algebra_instruction);
73       break;
74     case spv::Op::OpVectorTimesScalar:
75       ReplaceOpVectorTimesScalar(ir_context, linear_algebra_instruction);
76       break;
77     case spv::Op::OpMatrixTimesScalar:
78       ReplaceOpMatrixTimesScalar(ir_context, linear_algebra_instruction);
79       break;
80     case spv::Op::OpVectorTimesMatrix:
81       ReplaceOpVectorTimesMatrix(ir_context, linear_algebra_instruction);
82       break;
83     case spv::Op::OpMatrixTimesVector:
84       ReplaceOpMatrixTimesVector(ir_context, linear_algebra_instruction);
85       break;
86     case spv::Op::OpMatrixTimesMatrix:
87       ReplaceOpMatrixTimesMatrix(ir_context, linear_algebra_instruction);
88       break;
89     case spv::Op::OpOuterProduct:
90       ReplaceOpOuterProduct(ir_context, linear_algebra_instruction);
91       break;
92     case spv::Op::OpDot:
93       ReplaceOpDot(ir_context, linear_algebra_instruction);
94       break;
95     default:
96       assert(false && "Should be unreachable.");
97       break;
98   }
99 
100   ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
101 }
102 
103 protobufs::Transformation
ToMessage() const104 TransformationReplaceLinearAlgebraInstruction::ToMessage() const {
105   protobufs::Transformation result;
106   *result.mutable_replace_linear_algebra_instruction() = message_;
107   return result;
108 }
109 
GetRequiredFreshIdCount(opt::IRContext * ir_context,opt::Instruction * instruction)110 uint32_t TransformationReplaceLinearAlgebraInstruction::GetRequiredFreshIdCount(
111     opt::IRContext* ir_context, opt::Instruction* instruction) {
112   // TODO(https://github.com/KhronosGroup/SPIRV-Tools/issues/3354):
113   // Right now we only support certain operations.
114   switch (instruction->opcode()) {
115     case spv::Op::OpTranspose: {
116       // For each matrix row, |2 * matrix_column_count| OpCompositeExtract and 1
117       // OpCompositeConstruct will be inserted.
118       auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
119           instruction->GetSingleWordInOperand(0));
120       uint32_t matrix_column_count =
121           ir_context->get_type_mgr()
122               ->GetType(matrix_instruction->type_id())
123               ->AsMatrix()
124               ->element_count();
125       uint32_t matrix_row_count = ir_context->get_type_mgr()
126                                       ->GetType(matrix_instruction->type_id())
127                                       ->AsMatrix()
128                                       ->element_type()
129                                       ->AsVector()
130                                       ->element_count();
131       return matrix_row_count * (2 * matrix_column_count + 1);
132     }
133     case spv::Op::OpVectorTimesScalar:
134       // For each vector component, 1 OpCompositeExtract and 1 OpFMul will be
135       // inserted.
136       return 2 *
137              ir_context->get_type_mgr()
138                  ->GetType(ir_context->get_def_use_mgr()
139                                ->GetDef(instruction->GetSingleWordInOperand(0))
140                                ->type_id())
141                  ->AsVector()
142                  ->element_count();
143     case spv::Op::OpMatrixTimesScalar: {
144       // For each matrix column, |1 + column.size| OpCompositeExtract,
145       // |column.size| OpFMul and 1 OpCompositeConstruct instructions will be
146       // inserted.
147       auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
148           instruction->GetSingleWordInOperand(0));
149       auto matrix_type =
150           ir_context->get_type_mgr()->GetType(matrix_instruction->type_id());
151       return 2 * matrix_type->AsMatrix()->element_count() *
152              (1 + matrix_type->AsMatrix()
153                       ->element_type()
154                       ->AsVector()
155                       ->element_count());
156     }
157     case spv::Op::OpVectorTimesMatrix: {
158       // For each vector component, 1 OpCompositeExtract instruction will be
159       // inserted. For each matrix column, |1 + vector_component_count|
160       // OpCompositeExtract, |vector_component_count| OpFMul and
161       // |vector_component_count - 1| OpFAdd instructions will be inserted.
162       auto vector_instruction = ir_context->get_def_use_mgr()->GetDef(
163           instruction->GetSingleWordInOperand(0));
164       auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
165           instruction->GetSingleWordInOperand(1));
166       uint32_t vector_component_count =
167           ir_context->get_type_mgr()
168               ->GetType(vector_instruction->type_id())
169               ->AsVector()
170               ->element_count();
171       uint32_t matrix_column_count =
172           ir_context->get_type_mgr()
173               ->GetType(matrix_instruction->type_id())
174               ->AsMatrix()
175               ->element_count();
176       return vector_component_count * (3 * matrix_column_count + 1);
177     }
178     case spv::Op::OpMatrixTimesVector: {
179       // For each matrix column, |1 + matrix_row_count| OpCompositeExtract
180       // will be inserted. For each matrix row, |matrix_column_count| OpFMul and
181       // |matrix_column_count - 1| OpFAdd instructions will be inserted. For
182       // each vector component, 1 OpCompositeExtract instruction will be
183       // inserted.
184       auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
185           instruction->GetSingleWordInOperand(0));
186       uint32_t matrix_column_count =
187           ir_context->get_type_mgr()
188               ->GetType(matrix_instruction->type_id())
189               ->AsMatrix()
190               ->element_count();
191       uint32_t matrix_row_count = ir_context->get_type_mgr()
192                                       ->GetType(matrix_instruction->type_id())
193                                       ->AsMatrix()
194                                       ->element_type()
195                                       ->AsVector()
196                                       ->element_count();
197       return 3 * matrix_column_count * matrix_row_count +
198              2 * matrix_column_count - matrix_row_count;
199     }
200     case spv::Op::OpMatrixTimesMatrix: {
201       // For each matrix 2 column, 1 OpCompositeExtract, 1 OpCompositeConstruct,
202       // |3 * matrix_1_row_count * matrix_1_column_count| OpCompositeExtract,
203       // |matrix_1_row_count * matrix_1_column_count| OpFMul,
204       // |matrix_1_row_count * (matrix_1_column_count - 1)| OpFAdd instructions
205       // will be inserted.
206       auto matrix_1_instruction = ir_context->get_def_use_mgr()->GetDef(
207           instruction->GetSingleWordInOperand(0));
208       uint32_t matrix_1_column_count =
209           ir_context->get_type_mgr()
210               ->GetType(matrix_1_instruction->type_id())
211               ->AsMatrix()
212               ->element_count();
213       uint32_t matrix_1_row_count =
214           ir_context->get_type_mgr()
215               ->GetType(matrix_1_instruction->type_id())
216               ->AsMatrix()
217               ->element_type()
218               ->AsVector()
219               ->element_count();
220 
221       auto matrix_2_instruction = ir_context->get_def_use_mgr()->GetDef(
222           instruction->GetSingleWordInOperand(1));
223       uint32_t matrix_2_column_count =
224           ir_context->get_type_mgr()
225               ->GetType(matrix_2_instruction->type_id())
226               ->AsMatrix()
227               ->element_count();
228       return matrix_2_column_count *
229              (2 + matrix_1_row_count * (5 * matrix_1_column_count - 1));
230     }
231     case spv::Op::OpOuterProduct: {
232       // For each |vector_2| component, |vector_1_component_count + 1|
233       // OpCompositeExtract, |vector_1_component_count| OpFMul and 1
234       // OpCompositeConstruct instructions will be inserted.
235       auto vector_1_instruction = ir_context->get_def_use_mgr()->GetDef(
236           instruction->GetSingleWordInOperand(0));
237       auto vector_2_instruction = ir_context->get_def_use_mgr()->GetDef(
238           instruction->GetSingleWordInOperand(1));
239       uint32_t vector_1_component_count =
240           ir_context->get_type_mgr()
241               ->GetType(vector_1_instruction->type_id())
242               ->AsVector()
243               ->element_count();
244       uint32_t vector_2_component_count =
245           ir_context->get_type_mgr()
246               ->GetType(vector_2_instruction->type_id())
247               ->AsVector()
248               ->element_count();
249       return 2 * vector_2_component_count * (vector_1_component_count + 1);
250     }
251     case spv::Op::OpDot:
252       // For each pair of vector components, 2 OpCompositeExtract and 1 OpFMul
253       // will be inserted. The first two OpFMul instructions will result the
254       // first OpFAdd instruction to be inserted. For each remaining OpFMul, 1
255       // OpFAdd will be inserted. The last OpFAdd instruction is got by changing
256       // the OpDot instruction.
257       return 4 * ir_context->get_type_mgr()
258                      ->GetType(
259                          ir_context->get_def_use_mgr()
260                              ->GetDef(instruction->GetSingleWordInOperand(0))
261                              ->type_id())
262                      ->AsVector()
263                      ->element_count() -
264              2;
265     default:
266       assert(false && "Unsupported linear algebra instruction.");
267       return 0;
268   }
269 }
270 
ReplaceOpTranspose(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const271 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpTranspose(
272     opt::IRContext* ir_context,
273     opt::Instruction* linear_algebra_instruction) const {
274   // Gets OpTranspose instruction information.
275   auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
276       linear_algebra_instruction->GetSingleWordInOperand(0));
277   uint32_t matrix_column_count = ir_context->get_type_mgr()
278                                      ->GetType(matrix_instruction->type_id())
279                                      ->AsMatrix()
280                                      ->element_count();
281   auto matrix_column_type = ir_context->get_type_mgr()
282                                 ->GetType(matrix_instruction->type_id())
283                                 ->AsMatrix()
284                                 ->element_type();
285   auto matrix_column_component_type =
286       matrix_column_type->AsVector()->element_type();
287   uint32_t matrix_row_count = matrix_column_type->AsVector()->element_count();
288   auto resulting_matrix_column_type =
289       ir_context->get_type_mgr()
290           ->GetType(linear_algebra_instruction->type_id())
291           ->AsMatrix()
292           ->element_type();
293 
294   uint32_t fresh_id_index = 0;
295   std::vector<uint32_t> result_column_ids(matrix_row_count);
296   for (uint32_t i = 0; i < matrix_row_count; i++) {
297     std::vector<uint32_t> column_component_ids(matrix_column_count);
298     for (uint32_t j = 0; j < matrix_column_count; j++) {
299       // Extracts the matrix column.
300       uint32_t matrix_column_id = message_.fresh_ids(fresh_id_index++);
301       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
302           ir_context, spv::Op::OpCompositeExtract,
303           ir_context->get_type_mgr()->GetId(matrix_column_type),
304           matrix_column_id,
305           opt::Instruction::OperandList(
306               {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}},
307                {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
308 
309       // Extracts the matrix column component.
310       column_component_ids[j] = message_.fresh_ids(fresh_id_index++);
311       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
312           ir_context, spv::Op::OpCompositeExtract,
313           ir_context->get_type_mgr()->GetId(matrix_column_component_type),
314           column_component_ids[j],
315           opt::Instruction::OperandList(
316               {{SPV_OPERAND_TYPE_ID, {matrix_column_id}},
317                {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
318     }
319 
320     // Inserts the resulting matrix column.
321     opt::Instruction::OperandList in_operands;
322     for (auto& column_component_id : column_component_ids) {
323       in_operands.push_back({SPV_OPERAND_TYPE_ID, {column_component_id}});
324     }
325     result_column_ids[i] = message_.fresh_ids(fresh_id_index++);
326     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
327         ir_context, spv::Op::OpCompositeConstruct,
328         ir_context->get_type_mgr()->GetId(resulting_matrix_column_type),
329         result_column_ids[i], opt::Instruction::OperandList(in_operands)));
330   }
331 
332   // The OpTranspose instruction is changed to an OpCompositeConstruct
333   // instruction.
334   linear_algebra_instruction->SetOpcode(spv::Op::OpCompositeConstruct);
335   linear_algebra_instruction->SetInOperand(0, {result_column_ids[0]});
336   for (uint32_t i = 1; i < result_column_ids.size(); i++) {
337     linear_algebra_instruction->AddOperand(
338         {SPV_OPERAND_TYPE_ID, {result_column_ids[i]}});
339   }
340 
341   fuzzerutil::UpdateModuleIdBound(
342       ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
343 }
344 
ReplaceOpVectorTimesScalar(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const345 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpVectorTimesScalar(
346     opt::IRContext* ir_context,
347     opt::Instruction* linear_algebra_instruction) const {
348   // Gets OpVectorTimesScalar in operands.
349   auto vector = ir_context->get_def_use_mgr()->GetDef(
350       linear_algebra_instruction->GetSingleWordInOperand(0));
351   auto scalar = ir_context->get_def_use_mgr()->GetDef(
352       linear_algebra_instruction->GetSingleWordInOperand(1));
353 
354   uint32_t vector_component_count = ir_context->get_type_mgr()
355                                         ->GetType(vector->type_id())
356                                         ->AsVector()
357                                         ->element_count();
358   std::vector<uint32_t> float_multiplication_ids(vector_component_count);
359   uint32_t fresh_id_index = 0;
360 
361   for (uint32_t i = 0; i < vector_component_count; i++) {
362     // Extracts |vector| component.
363     uint32_t vector_extract_id = message_.fresh_ids(fresh_id_index++);
364     fuzzerutil::UpdateModuleIdBound(ir_context, vector_extract_id);
365     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
366         ir_context, spv::Op::OpCompositeExtract, scalar->type_id(),
367         vector_extract_id,
368         opt::Instruction::OperandList(
369             {{SPV_OPERAND_TYPE_ID, {vector->result_id()}},
370              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
371 
372     // Multiplies the |vector| component with the |scalar|.
373     uint32_t float_multiplication_id = message_.fresh_ids(fresh_id_index++);
374     float_multiplication_ids[i] = float_multiplication_id;
375     fuzzerutil::UpdateModuleIdBound(ir_context, float_multiplication_id);
376     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
377         ir_context, spv::Op::OpFMul, scalar->type_id(), float_multiplication_id,
378         opt::Instruction::OperandList(
379             {{SPV_OPERAND_TYPE_ID, {vector_extract_id}},
380              {SPV_OPERAND_TYPE_ID, {scalar->result_id()}}})));
381   }
382 
383   // The OpVectorTimesScalar instruction is changed to an OpCompositeConstruct
384   // instruction.
385   linear_algebra_instruction->SetOpcode(spv::Op::OpCompositeConstruct);
386   linear_algebra_instruction->SetInOperand(0, {float_multiplication_ids[0]});
387   linear_algebra_instruction->SetInOperand(1, {float_multiplication_ids[1]});
388   for (uint32_t i = 2; i < float_multiplication_ids.size(); i++) {
389     linear_algebra_instruction->AddOperand(
390         {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[i]}});
391   }
392 }
393 
ReplaceOpMatrixTimesScalar(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const394 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpMatrixTimesScalar(
395     opt::IRContext* ir_context,
396     opt::Instruction* linear_algebra_instruction) const {
397   // Gets OpMatrixTimesScalar in operands.
398   auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
399       linear_algebra_instruction->GetSingleWordInOperand(0));
400   auto scalar_instruction = ir_context->get_def_use_mgr()->GetDef(
401       linear_algebra_instruction->GetSingleWordInOperand(1));
402 
403   // Gets matrix information.
404   uint32_t matrix_column_count = ir_context->get_type_mgr()
405                                      ->GetType(matrix_instruction->type_id())
406                                      ->AsMatrix()
407                                      ->element_count();
408   auto matrix_column_type = ir_context->get_type_mgr()
409                                 ->GetType(matrix_instruction->type_id())
410                                 ->AsMatrix()
411                                 ->element_type();
412   uint32_t matrix_column_size = matrix_column_type->AsVector()->element_count();
413 
414   std::vector<uint32_t> composite_construct_ids(matrix_column_count);
415   uint32_t fresh_id_index = 0;
416 
417   for (uint32_t i = 0; i < matrix_column_count; i++) {
418     // Extracts |matrix| column.
419     uint32_t matrix_extract_id = message_.fresh_ids(fresh_id_index++);
420     fuzzerutil::UpdateModuleIdBound(ir_context, matrix_extract_id);
421     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
422         ir_context, spv::Op::OpCompositeExtract,
423         ir_context->get_type_mgr()->GetId(matrix_column_type),
424         matrix_extract_id,
425         opt::Instruction::OperandList(
426             {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}},
427              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
428 
429     std::vector<uint32_t> float_multiplication_ids(matrix_column_size);
430 
431     for (uint32_t j = 0; j < matrix_column_size; j++) {
432       // Extracts |column| component.
433       uint32_t column_extract_id = message_.fresh_ids(fresh_id_index++);
434       fuzzerutil::UpdateModuleIdBound(ir_context, column_extract_id);
435       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
436           ir_context, spv::Op::OpCompositeExtract,
437           scalar_instruction->type_id(), column_extract_id,
438           opt::Instruction::OperandList(
439               {{SPV_OPERAND_TYPE_ID, {matrix_extract_id}},
440                {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
441 
442       // Multiplies the |column| component with the |scalar|.
443       float_multiplication_ids[j] = message_.fresh_ids(fresh_id_index++);
444       fuzzerutil::UpdateModuleIdBound(ir_context, float_multiplication_ids[j]);
445       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
446           ir_context, spv::Op::OpFMul, scalar_instruction->type_id(),
447           float_multiplication_ids[j],
448           opt::Instruction::OperandList(
449               {{SPV_OPERAND_TYPE_ID, {column_extract_id}},
450                {SPV_OPERAND_TYPE_ID, {scalar_instruction->result_id()}}})));
451     }
452 
453     // Constructs a new column multiplied by |scalar|.
454     opt::Instruction::OperandList composite_construct_in_operands;
455     for (uint32_t& float_multiplication_id : float_multiplication_ids) {
456       composite_construct_in_operands.push_back(
457           {SPV_OPERAND_TYPE_ID, {float_multiplication_id}});
458     }
459     composite_construct_ids[i] = message_.fresh_ids(fresh_id_index++);
460     fuzzerutil::UpdateModuleIdBound(ir_context, composite_construct_ids[i]);
461     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
462         ir_context, spv::Op::OpCompositeConstruct,
463         ir_context->get_type_mgr()->GetId(matrix_column_type),
464         composite_construct_ids[i], composite_construct_in_operands));
465   }
466 
467   // The OpMatrixTimesScalar instruction is changed to an OpCompositeConstruct
468   // instruction.
469   linear_algebra_instruction->SetOpcode(spv::Op::OpCompositeConstruct);
470   linear_algebra_instruction->SetInOperand(0, {composite_construct_ids[0]});
471   linear_algebra_instruction->SetInOperand(1, {composite_construct_ids[1]});
472   for (uint32_t i = 2; i < composite_construct_ids.size(); i++) {
473     linear_algebra_instruction->AddOperand(
474         {SPV_OPERAND_TYPE_ID, {composite_construct_ids[i]}});
475   }
476 }
477 
ReplaceOpVectorTimesMatrix(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const478 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpVectorTimesMatrix(
479     opt::IRContext* ir_context,
480     opt::Instruction* linear_algebra_instruction) const {
481   // Gets vector information.
482   auto vector_instruction = ir_context->get_def_use_mgr()->GetDef(
483       linear_algebra_instruction->GetSingleWordInOperand(0));
484   uint32_t vector_component_count = ir_context->get_type_mgr()
485                                         ->GetType(vector_instruction->type_id())
486                                         ->AsVector()
487                                         ->element_count();
488   auto vector_component_type = ir_context->get_type_mgr()
489                                    ->GetType(vector_instruction->type_id())
490                                    ->AsVector()
491                                    ->element_type();
492 
493   // Extracts vector components.
494   uint32_t fresh_id_index = 0;
495   std::vector<uint32_t> vector_component_ids(vector_component_count);
496   for (uint32_t i = 0; i < vector_component_count; i++) {
497     vector_component_ids[i] = message_.fresh_ids(fresh_id_index++);
498     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
499         ir_context, spv::Op::OpCompositeExtract,
500         ir_context->get_type_mgr()->GetId(vector_component_type),
501         vector_component_ids[i],
502         opt::Instruction::OperandList(
503             {{SPV_OPERAND_TYPE_ID, {vector_instruction->result_id()}},
504              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
505   }
506 
507   // Gets matrix information.
508   auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
509       linear_algebra_instruction->GetSingleWordInOperand(1));
510   uint32_t matrix_column_count = ir_context->get_type_mgr()
511                                      ->GetType(matrix_instruction->type_id())
512                                      ->AsMatrix()
513                                      ->element_count();
514   auto matrix_column_type = ir_context->get_type_mgr()
515                                 ->GetType(matrix_instruction->type_id())
516                                 ->AsMatrix()
517                                 ->element_type();
518 
519   std::vector<uint32_t> result_component_ids(matrix_column_count);
520   for (uint32_t i = 0; i < matrix_column_count; i++) {
521     // Extracts matrix column.
522     uint32_t matrix_extract_id = message_.fresh_ids(fresh_id_index++);
523     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
524         ir_context, spv::Op::OpCompositeExtract,
525         ir_context->get_type_mgr()->GetId(matrix_column_type),
526         matrix_extract_id,
527         opt::Instruction::OperandList(
528             {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}},
529              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
530 
531     std::vector<uint32_t> float_multiplication_ids(vector_component_count);
532     for (uint32_t j = 0; j < vector_component_count; j++) {
533       // Extracts column component.
534       uint32_t column_extract_id = message_.fresh_ids(fresh_id_index++);
535       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
536           ir_context, spv::Op::OpCompositeExtract,
537           ir_context->get_type_mgr()->GetId(vector_component_type),
538           column_extract_id,
539           opt::Instruction::OperandList(
540               {{SPV_OPERAND_TYPE_ID, {matrix_extract_id}},
541                {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
542 
543       // Multiplies corresponding vector and column components.
544       float_multiplication_ids[j] = message_.fresh_ids(fresh_id_index++);
545       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
546           ir_context, spv::Op::OpFMul,
547           ir_context->get_type_mgr()->GetId(vector_component_type),
548           float_multiplication_ids[j],
549           opt::Instruction::OperandList(
550               {{SPV_OPERAND_TYPE_ID, {vector_component_ids[j]}},
551                {SPV_OPERAND_TYPE_ID, {column_extract_id}}})));
552     }
553 
554     // Adds the multiplication results.
555     std::vector<uint32_t> float_add_ids;
556     uint32_t float_add_id = message_.fresh_ids(fresh_id_index++);
557     float_add_ids.push_back(float_add_id);
558     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
559         ir_context, spv::Op::OpFAdd,
560         ir_context->get_type_mgr()->GetId(vector_component_type), float_add_id,
561         opt::Instruction::OperandList(
562             {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}},
563              {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}})));
564     for (uint32_t j = 2; j < float_multiplication_ids.size(); j++) {
565       float_add_id = message_.fresh_ids(fresh_id_index++);
566       float_add_ids.push_back(float_add_id);
567       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
568           ir_context, spv::Op::OpFAdd,
569           ir_context->get_type_mgr()->GetId(vector_component_type),
570           float_add_id,
571           opt::Instruction::OperandList(
572               {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[j]}},
573                {SPV_OPERAND_TYPE_ID, {float_add_ids[j - 2]}}})));
574     }
575 
576     result_component_ids[i] = float_add_ids.back();
577   }
578 
579   // The OpVectorTimesMatrix instruction is changed to an OpCompositeConstruct
580   // instruction.
581   linear_algebra_instruction->SetOpcode(spv::Op::OpCompositeConstruct);
582   linear_algebra_instruction->SetInOperand(0, {result_component_ids[0]});
583   linear_algebra_instruction->SetInOperand(1, {result_component_ids[1]});
584   for (uint32_t i = 2; i < result_component_ids.size(); i++) {
585     linear_algebra_instruction->AddOperand(
586         {SPV_OPERAND_TYPE_ID, {result_component_ids[i]}});
587   }
588 
589   fuzzerutil::UpdateModuleIdBound(
590       ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
591 }
592 
ReplaceOpMatrixTimesVector(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const593 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpMatrixTimesVector(
594     opt::IRContext* ir_context,
595     opt::Instruction* linear_algebra_instruction) const {
596   // Gets matrix information.
597   auto matrix_instruction = ir_context->get_def_use_mgr()->GetDef(
598       linear_algebra_instruction->GetSingleWordInOperand(0));
599   uint32_t matrix_column_count = ir_context->get_type_mgr()
600                                      ->GetType(matrix_instruction->type_id())
601                                      ->AsMatrix()
602                                      ->element_count();
603   auto matrix_column_type = ir_context->get_type_mgr()
604                                 ->GetType(matrix_instruction->type_id())
605                                 ->AsMatrix()
606                                 ->element_type();
607   uint32_t matrix_row_count = matrix_column_type->AsVector()->element_count();
608 
609   // Extracts matrix columns.
610   uint32_t fresh_id_index = 0;
611   std::vector<uint32_t> matrix_column_ids(matrix_column_count);
612   for (uint32_t i = 0; i < matrix_column_count; i++) {
613     matrix_column_ids[i] = message_.fresh_ids(fresh_id_index++);
614     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
615         ir_context, spv::Op::OpCompositeExtract,
616         ir_context->get_type_mgr()->GetId(matrix_column_type),
617         matrix_column_ids[i],
618         opt::Instruction::OperandList(
619             {{SPV_OPERAND_TYPE_ID, {matrix_instruction->result_id()}},
620              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
621   }
622 
623   // Gets vector information.
624   auto vector_instruction = ir_context->get_def_use_mgr()->GetDef(
625       linear_algebra_instruction->GetSingleWordInOperand(1));
626   auto vector_component_type = ir_context->get_type_mgr()
627                                    ->GetType(vector_instruction->type_id())
628                                    ->AsVector()
629                                    ->element_type();
630 
631   // Extracts vector components.
632   std::vector<uint32_t> vector_component_ids(matrix_column_count);
633   for (uint32_t i = 0; i < matrix_column_count; i++) {
634     vector_component_ids[i] = message_.fresh_ids(fresh_id_index++);
635     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
636         ir_context, spv::Op::OpCompositeExtract,
637         ir_context->get_type_mgr()->GetId(vector_component_type),
638         vector_component_ids[i],
639         opt::Instruction::OperandList(
640             {{SPV_OPERAND_TYPE_ID, {vector_instruction->result_id()}},
641              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
642   }
643 
644   std::vector<uint32_t> result_component_ids(matrix_row_count);
645   for (uint32_t i = 0; i < matrix_row_count; i++) {
646     std::vector<uint32_t> float_multiplication_ids(matrix_column_count);
647     for (uint32_t j = 0; j < matrix_column_count; j++) {
648       // Extracts column component.
649       uint32_t column_extract_id = message_.fresh_ids(fresh_id_index++);
650       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
651           ir_context, spv::Op::OpCompositeExtract,
652           ir_context->get_type_mgr()->GetId(vector_component_type),
653           column_extract_id,
654           opt::Instruction::OperandList(
655               {{SPV_OPERAND_TYPE_ID, {matrix_column_ids[j]}},
656                {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
657 
658       // Multiplies corresponding vector and column components.
659       float_multiplication_ids[j] = message_.fresh_ids(fresh_id_index++);
660       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
661           ir_context, spv::Op::OpFMul,
662           ir_context->get_type_mgr()->GetId(vector_component_type),
663           float_multiplication_ids[j],
664           opt::Instruction::OperandList(
665               {{SPV_OPERAND_TYPE_ID, {column_extract_id}},
666                {SPV_OPERAND_TYPE_ID, {vector_component_ids[j]}}})));
667     }
668 
669     // Adds the multiplication results.
670     std::vector<uint32_t> float_add_ids;
671     uint32_t float_add_id = message_.fresh_ids(fresh_id_index++);
672     float_add_ids.push_back(float_add_id);
673     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
674         ir_context, spv::Op::OpFAdd,
675         ir_context->get_type_mgr()->GetId(vector_component_type), float_add_id,
676         opt::Instruction::OperandList(
677             {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}},
678              {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}})));
679     for (uint32_t j = 2; j < float_multiplication_ids.size(); j++) {
680       float_add_id = message_.fresh_ids(fresh_id_index++);
681       float_add_ids.push_back(float_add_id);
682       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
683           ir_context, spv::Op::OpFAdd,
684           ir_context->get_type_mgr()->GetId(vector_component_type),
685           float_add_id,
686           opt::Instruction::OperandList(
687               {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[j]}},
688                {SPV_OPERAND_TYPE_ID, {float_add_ids[j - 2]}}})));
689     }
690 
691     result_component_ids[i] = float_add_ids.back();
692   }
693 
694   // The OpMatrixTimesVector instruction is changed to an OpCompositeConstruct
695   // instruction.
696   linear_algebra_instruction->SetOpcode(spv::Op::OpCompositeConstruct);
697   linear_algebra_instruction->SetInOperand(0, {result_component_ids[0]});
698   linear_algebra_instruction->SetInOperand(1, {result_component_ids[1]});
699   for (uint32_t i = 2; i < result_component_ids.size(); i++) {
700     linear_algebra_instruction->AddOperand(
701         {SPV_OPERAND_TYPE_ID, {result_component_ids[i]}});
702   }
703 
704   fuzzerutil::UpdateModuleIdBound(
705       ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
706 }
707 
ReplaceOpMatrixTimesMatrix(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const708 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpMatrixTimesMatrix(
709     opt::IRContext* ir_context,
710     opt::Instruction* linear_algebra_instruction) const {
711   // Gets matrix 1 information.
712   auto matrix_1_instruction = ir_context->get_def_use_mgr()->GetDef(
713       linear_algebra_instruction->GetSingleWordInOperand(0));
714   uint32_t matrix_1_column_count =
715       ir_context->get_type_mgr()
716           ->GetType(matrix_1_instruction->type_id())
717           ->AsMatrix()
718           ->element_count();
719   auto matrix_1_column_type = ir_context->get_type_mgr()
720                                   ->GetType(matrix_1_instruction->type_id())
721                                   ->AsMatrix()
722                                   ->element_type();
723   auto matrix_1_column_component_type =
724       matrix_1_column_type->AsVector()->element_type();
725   uint32_t matrix_1_row_count =
726       matrix_1_column_type->AsVector()->element_count();
727 
728   // Gets matrix 2 information.
729   auto matrix_2_instruction = ir_context->get_def_use_mgr()->GetDef(
730       linear_algebra_instruction->GetSingleWordInOperand(1));
731   uint32_t matrix_2_column_count =
732       ir_context->get_type_mgr()
733           ->GetType(matrix_2_instruction->type_id())
734           ->AsMatrix()
735           ->element_count();
736   auto matrix_2_column_type = ir_context->get_type_mgr()
737                                   ->GetType(matrix_2_instruction->type_id())
738                                   ->AsMatrix()
739                                   ->element_type();
740 
741   uint32_t fresh_id_index = 0;
742   std::vector<uint32_t> result_column_ids(matrix_2_column_count);
743   for (uint32_t i = 0; i < matrix_2_column_count; i++) {
744     // Extracts matrix 2 column.
745     uint32_t matrix_2_column_id = message_.fresh_ids(fresh_id_index++);
746     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
747         ir_context, spv::Op::OpCompositeExtract,
748         ir_context->get_type_mgr()->GetId(matrix_2_column_type),
749         matrix_2_column_id,
750         opt::Instruction::OperandList(
751             {{SPV_OPERAND_TYPE_ID, {matrix_2_instruction->result_id()}},
752              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
753 
754     std::vector<uint32_t> column_component_ids(matrix_1_row_count);
755     for (uint32_t j = 0; j < matrix_1_row_count; j++) {
756       std::vector<uint32_t> float_multiplication_ids(matrix_1_column_count);
757       for (uint32_t k = 0; k < matrix_1_column_count; k++) {
758         // Extracts matrix 1 column.
759         uint32_t matrix_1_column_id = message_.fresh_ids(fresh_id_index++);
760         linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
761             ir_context, spv::Op::OpCompositeExtract,
762             ir_context->get_type_mgr()->GetId(matrix_1_column_type),
763             matrix_1_column_id,
764             opt::Instruction::OperandList(
765                 {{SPV_OPERAND_TYPE_ID, {matrix_1_instruction->result_id()}},
766                  {SPV_OPERAND_TYPE_LITERAL_INTEGER, {k}}})));
767 
768         // Extracts matrix 1 column component.
769         uint32_t matrix_1_column_component_id =
770             message_.fresh_ids(fresh_id_index++);
771         linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
772             ir_context, spv::Op::OpCompositeExtract,
773             ir_context->get_type_mgr()->GetId(matrix_1_column_component_type),
774             matrix_1_column_component_id,
775             opt::Instruction::OperandList(
776                 {{SPV_OPERAND_TYPE_ID, {matrix_1_column_id}},
777                  {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
778 
779         // Extracts matrix 2 column component.
780         uint32_t matrix_2_column_component_id =
781             message_.fresh_ids(fresh_id_index++);
782         linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
783             ir_context, spv::Op::OpCompositeExtract,
784             ir_context->get_type_mgr()->GetId(matrix_1_column_component_type),
785             matrix_2_column_component_id,
786             opt::Instruction::OperandList(
787                 {{SPV_OPERAND_TYPE_ID, {matrix_2_column_id}},
788                  {SPV_OPERAND_TYPE_LITERAL_INTEGER, {k}}})));
789 
790         // Multiplies corresponding matrix 1 and matrix 2 column components.
791         float_multiplication_ids[k] = message_.fresh_ids(fresh_id_index++);
792         linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
793             ir_context, spv::Op::OpFMul,
794             ir_context->get_type_mgr()->GetId(matrix_1_column_component_type),
795             float_multiplication_ids[k],
796             opt::Instruction::OperandList(
797                 {{SPV_OPERAND_TYPE_ID, {matrix_1_column_component_id}},
798                  {SPV_OPERAND_TYPE_ID, {matrix_2_column_component_id}}})));
799       }
800 
801       // Adds the multiplication results.
802       std::vector<uint32_t> float_add_ids;
803       uint32_t float_add_id = message_.fresh_ids(fresh_id_index++);
804       float_add_ids.push_back(float_add_id);
805       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
806           ir_context, spv::Op::OpFAdd,
807           ir_context->get_type_mgr()->GetId(matrix_1_column_component_type),
808           float_add_id,
809           opt::Instruction::OperandList(
810               {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}},
811                {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}})));
812       for (uint32_t k = 2; k < float_multiplication_ids.size(); k++) {
813         float_add_id = message_.fresh_ids(fresh_id_index++);
814         float_add_ids.push_back(float_add_id);
815         linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
816             ir_context, spv::Op::OpFAdd,
817             ir_context->get_type_mgr()->GetId(matrix_1_column_component_type),
818             float_add_id,
819             opt::Instruction::OperandList(
820                 {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[k]}},
821                  {SPV_OPERAND_TYPE_ID, {float_add_ids[k - 2]}}})));
822       }
823 
824       column_component_ids[j] = float_add_ids.back();
825     }
826 
827     // Inserts the resulting matrix column.
828     opt::Instruction::OperandList in_operands;
829     for (auto& column_component_id : column_component_ids) {
830       in_operands.push_back({SPV_OPERAND_TYPE_ID, {column_component_id}});
831     }
832     result_column_ids[i] = message_.fresh_ids(fresh_id_index++);
833     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
834         ir_context, spv::Op::OpCompositeConstruct,
835         ir_context->get_type_mgr()->GetId(matrix_1_column_type),
836         result_column_ids[i], opt::Instruction::OperandList(in_operands)));
837   }
838 
839   // The OpMatrixTimesMatrix instruction is changed to an OpCompositeConstruct
840   // instruction.
841   linear_algebra_instruction->SetOpcode(spv::Op::OpCompositeConstruct);
842   linear_algebra_instruction->SetInOperand(0, {result_column_ids[0]});
843   linear_algebra_instruction->SetInOperand(1, {result_column_ids[1]});
844   for (uint32_t i = 2; i < result_column_ids.size(); i++) {
845     linear_algebra_instruction->AddOperand(
846         {SPV_OPERAND_TYPE_ID, {result_column_ids[i]}});
847   }
848 
849   fuzzerutil::UpdateModuleIdBound(
850       ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
851 }
852 
ReplaceOpOuterProduct(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const853 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpOuterProduct(
854     opt::IRContext* ir_context,
855     opt::Instruction* linear_algebra_instruction) const {
856   // Gets vector 1 information.
857   auto vector_1_instruction = ir_context->get_def_use_mgr()->GetDef(
858       linear_algebra_instruction->GetSingleWordInOperand(0));
859   uint32_t vector_1_component_count =
860       ir_context->get_type_mgr()
861           ->GetType(vector_1_instruction->type_id())
862           ->AsVector()
863           ->element_count();
864   auto vector_1_component_type = ir_context->get_type_mgr()
865                                      ->GetType(vector_1_instruction->type_id())
866                                      ->AsVector()
867                                      ->element_type();
868 
869   // Gets vector 2 information.
870   auto vector_2_instruction = ir_context->get_def_use_mgr()->GetDef(
871       linear_algebra_instruction->GetSingleWordInOperand(1));
872   uint32_t vector_2_component_count =
873       ir_context->get_type_mgr()
874           ->GetType(vector_2_instruction->type_id())
875           ->AsVector()
876           ->element_count();
877 
878   uint32_t fresh_id_index = 0;
879   std::vector<uint32_t> result_column_ids(vector_2_component_count);
880   for (uint32_t i = 0; i < vector_2_component_count; i++) {
881     // Extracts |vector_2| component.
882     uint32_t vector_2_component_id = message_.fresh_ids(fresh_id_index++);
883     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
884         ir_context, spv::Op::OpCompositeExtract,
885         ir_context->get_type_mgr()->GetId(vector_1_component_type),
886         vector_2_component_id,
887         opt::Instruction::OperandList(
888             {{SPV_OPERAND_TYPE_ID, {vector_2_instruction->result_id()}},
889              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
890 
891     std::vector<uint32_t> column_component_ids(vector_1_component_count);
892     for (uint32_t j = 0; j < vector_1_component_count; j++) {
893       // Extracts |vector_1| component.
894       uint32_t vector_1_component_id = message_.fresh_ids(fresh_id_index++);
895       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
896           ir_context, spv::Op::OpCompositeExtract,
897           ir_context->get_type_mgr()->GetId(vector_1_component_type),
898           vector_1_component_id,
899           opt::Instruction::OperandList(
900               {{SPV_OPERAND_TYPE_ID, {vector_1_instruction->result_id()}},
901                {SPV_OPERAND_TYPE_LITERAL_INTEGER, {j}}})));
902 
903       // Multiplies |vector_1| and |vector_2| components.
904       column_component_ids[j] = message_.fresh_ids(fresh_id_index++);
905       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
906           ir_context, spv::Op::OpFMul,
907           ir_context->get_type_mgr()->GetId(vector_1_component_type),
908           column_component_ids[j],
909           opt::Instruction::OperandList(
910               {{SPV_OPERAND_TYPE_ID, {vector_2_component_id}},
911                {SPV_OPERAND_TYPE_ID, {vector_1_component_id}}})));
912     }
913 
914     // Inserts the resulting matrix column.
915     opt::Instruction::OperandList in_operands;
916     for (auto& column_component_id : column_component_ids) {
917       in_operands.push_back({SPV_OPERAND_TYPE_ID, {column_component_id}});
918     }
919     result_column_ids[i] = message_.fresh_ids(fresh_id_index++);
920     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
921         ir_context, spv::Op::OpCompositeConstruct,
922         vector_1_instruction->type_id(), result_column_ids[i], in_operands));
923   }
924 
925   // The OpOuterProduct instruction is changed to an OpCompositeConstruct
926   // instruction.
927   linear_algebra_instruction->SetOpcode(spv::Op::OpCompositeConstruct);
928   linear_algebra_instruction->SetInOperand(0, {result_column_ids[0]});
929   linear_algebra_instruction->SetInOperand(1, {result_column_ids[1]});
930   for (uint32_t i = 2; i < result_column_ids.size(); i++) {
931     linear_algebra_instruction->AddOperand(
932         {SPV_OPERAND_TYPE_ID, {result_column_ids[i]}});
933   }
934 
935   fuzzerutil::UpdateModuleIdBound(
936       ir_context, message_.fresh_ids(message_.fresh_ids().size() - 1));
937 }
938 
ReplaceOpDot(opt::IRContext * ir_context,opt::Instruction * linear_algebra_instruction) const939 void TransformationReplaceLinearAlgebraInstruction::ReplaceOpDot(
940     opt::IRContext* ir_context,
941     opt::Instruction* linear_algebra_instruction) const {
942   // Gets OpDot in operands.
943   auto vector_1 = ir_context->get_def_use_mgr()->GetDef(
944       linear_algebra_instruction->GetSingleWordInOperand(0));
945   auto vector_2 = ir_context->get_def_use_mgr()->GetDef(
946       linear_algebra_instruction->GetSingleWordInOperand(1));
947 
948   uint32_t vectors_component_count = ir_context->get_type_mgr()
949                                          ->GetType(vector_1->type_id())
950                                          ->AsVector()
951                                          ->element_count();
952   std::vector<uint32_t> float_multiplication_ids(vectors_component_count);
953   uint32_t fresh_id_index = 0;
954 
955   for (uint32_t i = 0; i < vectors_component_count; i++) {
956     // Extracts |vector_1| component.
957     uint32_t vector_1_extract_id = message_.fresh_ids(fresh_id_index++);
958     fuzzerutil::UpdateModuleIdBound(ir_context, vector_1_extract_id);
959     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
960         ir_context, spv::Op::OpCompositeExtract,
961         linear_algebra_instruction->type_id(), vector_1_extract_id,
962         opt::Instruction::OperandList(
963             {{SPV_OPERAND_TYPE_ID, {vector_1->result_id()}},
964              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
965 
966     // Extracts |vector_2| component.
967     uint32_t vector_2_extract_id = message_.fresh_ids(fresh_id_index++);
968     fuzzerutil::UpdateModuleIdBound(ir_context, vector_2_extract_id);
969     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
970         ir_context, spv::Op::OpCompositeExtract,
971         linear_algebra_instruction->type_id(), vector_2_extract_id,
972         opt::Instruction::OperandList(
973             {{SPV_OPERAND_TYPE_ID, {vector_2->result_id()}},
974              {SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}})));
975 
976     // Multiplies the pair of components.
977     float_multiplication_ids[i] = message_.fresh_ids(fresh_id_index++);
978     fuzzerutil::UpdateModuleIdBound(ir_context, float_multiplication_ids[i]);
979     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
980         ir_context, spv::Op::OpFMul, linear_algebra_instruction->type_id(),
981         float_multiplication_ids[i],
982         opt::Instruction::OperandList(
983             {{SPV_OPERAND_TYPE_ID, {vector_1_extract_id}},
984              {SPV_OPERAND_TYPE_ID, {vector_2_extract_id}}})));
985   }
986 
987   // If the vector has 2 components, then there will be 2 float multiplication
988   // instructions.
989   if (vectors_component_count == 2) {
990     linear_algebra_instruction->SetOpcode(spv::Op::OpFAdd);
991     linear_algebra_instruction->SetInOperand(0, {float_multiplication_ids[0]});
992     linear_algebra_instruction->SetInOperand(1, {float_multiplication_ids[1]});
993   } else {
994     // The first OpFAdd instruction has as operands the first two OpFMul
995     // instructions.
996     std::vector<uint32_t> float_add_ids;
997     uint32_t float_add_id = message_.fresh_ids(fresh_id_index++);
998     float_add_ids.push_back(float_add_id);
999     fuzzerutil::UpdateModuleIdBound(ir_context, float_add_id);
1000     linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
1001         ir_context, spv::Op::OpFAdd, linear_algebra_instruction->type_id(),
1002         float_add_id,
1003         opt::Instruction::OperandList(
1004             {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[0]}},
1005              {SPV_OPERAND_TYPE_ID, {float_multiplication_ids[1]}}})));
1006 
1007     // The remaining OpFAdd instructions has as operands an OpFMul and an OpFAdd
1008     // instruction.
1009     for (uint32_t i = 2; i < float_multiplication_ids.size() - 1; i++) {
1010       float_add_id = message_.fresh_ids(fresh_id_index++);
1011       fuzzerutil::UpdateModuleIdBound(ir_context, float_add_id);
1012       float_add_ids.push_back(float_add_id);
1013       linear_algebra_instruction->InsertBefore(MakeUnique<opt::Instruction>(
1014           ir_context, spv::Op::OpFAdd, linear_algebra_instruction->type_id(),
1015           float_add_id,
1016           opt::Instruction::OperandList(
1017               {{SPV_OPERAND_TYPE_ID, {float_multiplication_ids[i]}},
1018                {SPV_OPERAND_TYPE_ID, {float_add_ids[i - 2]}}})));
1019     }
1020 
1021     // The last OpFAdd instruction is got by changing some of the OpDot
1022     // instruction attributes.
1023     linear_algebra_instruction->SetOpcode(spv::Op::OpFAdd);
1024     linear_algebra_instruction->SetInOperand(
1025         0, {float_multiplication_ids[float_multiplication_ids.size() - 1]});
1026     linear_algebra_instruction->SetInOperand(
1027         1, {float_add_ids[float_add_ids.size() - 1]});
1028   }
1029 }
1030 
1031 std::unordered_set<uint32_t>
GetFreshIds() const1032 TransformationReplaceLinearAlgebraInstruction::GetFreshIds() const {
1033   std::unordered_set<uint32_t> result;
1034   for (auto id : message_.fresh_ids()) {
1035     result.insert(id);
1036   }
1037   return result;
1038 }
1039 
1040 }  // namespace fuzz
1041 }  // namespace spvtools
1042