• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2020 André Perez Maselco
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef SOURCE_FUZZ_TRANSFORMATION_REPLACE_LINEAR_ALGEBRA_INSTRUCTION_H_
16 #define SOURCE_FUZZ_TRANSFORMATION_REPLACE_LINEAR_ALGEBRA_INSTRUCTION_H_
17 
18 #include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
19 #include "source/fuzz/transformation.h"
20 #include "source/fuzz/transformation_context.h"
21 #include "source/opt/ir_context.h"
22 
23 namespace spvtools {
24 namespace fuzz {
25 
// Fuzzer transformation, derived from |Transformation|, that replaces a linear
// algebra instruction. Its parameters live in the protobuf message |message_|;
// the comments below refer to that message's |fresh_ids| and
// |instruction_descriptor| fields. The private helpers cover OpTranspose,
// OpVectorTimesScalar, OpMatrixTimesScalar, OpVectorTimesMatrix,
// OpMatrixTimesVector, OpMatrixTimesMatrix, OpOuterProduct and OpDot.
//
// NOTE(review): std::vector, std::unordered_set and uint32_t are used in this
// class, but the header shows no direct include of <vector>, <unordered_set>
// or <cstdint>; presumably they arrive transitively through the project
// headers included above - confirm.
26 class TransformationReplaceLinearAlgebraInstruction : public Transformation {
27  public:
     // Constructs the transformation from an already-built protobuf message,
     // which is taken by value. The constructor is explicit, so a message is
     // never implicitly converted into a transformation.
28   explicit TransformationReplaceLinearAlgebraInstruction(
29       protobufs::TransformationReplaceLinearAlgebraInstruction message);
30 
     // Constructs the transformation from its individual parameters:
     // - |fresh_ids|: the ids to be used when applying the transformation.
     // - |instruction_descriptor|: describes the instruction to be replaced.
     // Presumably both are stored into |message_|; the definition is not part
     // of this header - confirm in the implementation file.
31   TransformationReplaceLinearAlgebraInstruction(
32       const std::vector<uint32_t>& fresh_ids,
33       const protobufs::InstructionDescriptor& instruction_descriptor);
34 
     // Applicability check for this transformation. Documented preconditions:
35   // - |message_.fresh_ids| must be fresh ids needed to apply the
36   // transformation.
37   // - |message_.instruction_descriptor| must be a linear algebra instruction.
     // NOTE(review): the number of ids expected is presumably the value
     // reported by GetRequiredFreshIdCount() - confirm in the implementation.
38   bool IsApplicable(
39       opt::IRContext* ir_context,
40       const TransformationContext& transformation_context) const override;
41 
42   // Replaces a linear algebra instruction.
     // Presumably dispatches to one of the private ReplaceOp* helpers declared
     // below, according to the opcode of the described instruction - confirm
     // in the implementation file.
43   void Apply(opt::IRContext* ir_context,
44              TransformationContext* transformation_context) const override;
45 
     // Returns a set of ids associated with this transformation; presumably
     // the ids held in |message_.fresh_ids| - confirm.
46   std::unordered_set<uint32_t> GetFreshIds() const override;
47 
     // Returns this transformation in the form of a generic
     // protobufs::Transformation message (presumably one wrapping |message_|).
48   protobufs::Transformation ToMessage() const override;
49 
50   // Returns the number of ids needed to apply the transformation.
     // Static, so it can be called without constructing a transformation
     // object. |instruction| is the instruction the count is computed for
     // (presumably the linear algebra instruction that would be replaced).
51   static uint32_t GetRequiredFreshIdCount(opt::IRContext* ir_context,
52                                           opt::Instruction* instruction);
53 
54  private:
     // Protobuf message holding this transformation's parameters (the
     // |fresh_ids| and |instruction_descriptor| fields referenced above).
55   protobufs::TransformationReplaceLinearAlgebraInstruction message_;
56 
     // Per-opcode replacement helpers. Each one is a const member that receives
     // the IR context and the instruction to replace. The replacement
     // instructions each helper emits are defined in the implementation file
     // and are not visible from this header.
     //
57   // Replaces an OpTranspose instruction.
58   void ReplaceOpTranspose(opt::IRContext* ir_context,
59                           opt::Instruction* instruction) const;
60 
61   // Replaces an OpVectorTimesScalar instruction.
62   void ReplaceOpVectorTimesScalar(opt::IRContext* ir_context,
63                                   opt::Instruction* instruction) const;
64 
65   // Replaces an OpMatrixTimesScalar instruction.
66   void ReplaceOpMatrixTimesScalar(opt::IRContext* ir_context,
67                                   opt::Instruction* instruction) const;
68 
69   // Replaces an OpVectorTimesMatrix instruction.
70   void ReplaceOpVectorTimesMatrix(opt::IRContext* ir_context,
71                                   opt::Instruction* instruction) const;
72 
73   // Replaces an OpMatrixTimesVector instruction.
74   void ReplaceOpMatrixTimesVector(opt::IRContext* ir_context,
75                                   opt::Instruction* instruction) const;
76 
77   // Replaces an OpMatrixTimesMatrix instruction.
78   void ReplaceOpMatrixTimesMatrix(opt::IRContext* ir_context,
79                                   opt::Instruction* instruction) const;
80 
81   // Replaces an OpOuterProduct instruction.
82   void ReplaceOpOuterProduct(opt::IRContext* ir_context,
83                              opt::Instruction* instruction) const;
84 
85   // Replaces an OpDot instruction.
86   void ReplaceOpDot(opt::IRContext* ir_context,
87                     opt::Instruction* instruction) const;
88 };
89 
90 }  // namespace fuzz
91 }  // namespace spvtools
92 
93 #endif  // SOURCE_FUZZ_TRANSFORMATION_REPLACE_LINEAR_ALGEBRA_INSTRUCTION_H_
94