1 // Copyright (c) 2020 André Perez Maselco 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef SOURCE_FUZZ_TRANSFORMATION_REPLACE_LINEAR_ALGEBRA_INSTRUCTION_H_ 16 #define SOURCE_FUZZ_TRANSFORMATION_REPLACE_LINEAR_ALGEBRA_INSTRUCTION_H_ 17 18 #include "source/fuzz/protobufs/spirvfuzz_protobufs.h" 19 #include "source/fuzz/transformation.h" 20 #include "source/fuzz/transformation_context.h" 21 #include "source/opt/ir_context.h" 22 23 namespace spvtools { 24 namespace fuzz { 25 26 class TransformationReplaceLinearAlgebraInstruction : public Transformation { 27 public: 28 explicit TransformationReplaceLinearAlgebraInstruction( 29 protobufs::TransformationReplaceLinearAlgebraInstruction message); 30 31 TransformationReplaceLinearAlgebraInstruction( 32 const std::vector<uint32_t>& fresh_ids, 33 const protobufs::InstructionDescriptor& instruction_descriptor); 34 35 // - |message_.fresh_ids| must be fresh ids needed to apply the 36 // transformation. 37 // - |message_.instruction_descriptor| must be a linear algebra instruction 38 bool IsApplicable( 39 opt::IRContext* ir_context, 40 const TransformationContext& transformation_context) const override; 41 42 // Replaces a linear algebra instruction. 43 void Apply(opt::IRContext* ir_context, 44 TransformationContext* transformation_context) const override; 45 46 std::unordered_set<uint32_t> GetFreshIds() const override; 47 48 protobufs::Transformation ToMessage() const override; 49 50 // Returns the number of ids needed to apply the transformation. 51 static uint32_t GetRequiredFreshIdCount(opt::IRContext* ir_context, 52 opt::Instruction* instruction); 53 54 private: 55 protobufs::TransformationReplaceLinearAlgebraInstruction message_; 56 57 // Replaces an OpTranspose instruction. 58 void ReplaceOpTranspose(opt::IRContext* ir_context, 59 opt::Instruction* instruction) const; 60 61 // Replaces an OpVectorTimesScalar instruction. 62 void ReplaceOpVectorTimesScalar(opt::IRContext* ir_context, 63 opt::Instruction* instruction) const; 64 65 // Replaces an OpMatrixTimesScalar instruction. 66 void ReplaceOpMatrixTimesScalar(opt::IRContext* ir_context, 67 opt::Instruction* instruction) const; 68 69 // Replaces an OpVectorTimesMatrix instruction. 70 void ReplaceOpVectorTimesMatrix(opt::IRContext* ir_context, 71 opt::Instruction* instruction) const; 72 73 // Replaces an OpMatrixTimesVector instruction. 74 void ReplaceOpMatrixTimesVector(opt::IRContext* ir_context, 75 opt::Instruction* instruction) const; 76 77 // Replaces an OpMatrixTimesMatrix instruction. 78 void ReplaceOpMatrixTimesMatrix(opt::IRContext* ir_context, 79 opt::Instruction* instruction) const; 80 81 // Replaces an OpOuterProduct instruction. 82 void ReplaceOpOuterProduct(opt::IRContext* ir_context, 83 opt::Instruction* instruction) const; 84 85 // Replaces an OpDot instruction. 86 void ReplaceOpDot(opt::IRContext* ir_context, 87 opt::Instruction* instruction) const; 88 }; 89 90 } // namespace fuzz 91 } // namespace spvtools 92 93 #endif // SOURCE_FUZZ_TRANSFORMATION_REPLACE_LINEAR_ALGEBRA_INSTRUCTION_H_ 94