1 //===- llvm/MatrixBuilder.h - Builder to lower matrix ops -------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the MatrixBuilder class, which is used as a convenient way 10 // to lower matrix operations to LLVM IR. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef LLVM_IR_MATRIXBUILDER_H 15 #define LLVM_IR_MATRIXBUILDER_H 16 17 #include "llvm/IR/Constant.h" 18 #include "llvm/IR/Constants.h" 19 #include "llvm/IR/IRBuilder.h" 20 #include "llvm/IR/InstrTypes.h" 21 #include "llvm/IR/Instruction.h" 22 #include "llvm/IR/IntrinsicInst.h" 23 #include "llvm/IR/Type.h" 24 #include "llvm/IR/Value.h" 25 #include "llvm/Support/Alignment.h" 26 27 namespace llvm { 28 29 class Function; 30 class Twine; 31 class Module; 32 33 template <class IRBuilderTy> class MatrixBuilder { 34 IRBuilderTy &B; getModule()35 Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); } 36 splatScalarOperandIfNeeded(Value * LHS,Value * RHS)37 std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS, 38 Value *RHS) { 39 assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) && 40 "One of the operands must be a matrix (embedded in a vector)"); 41 if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) { 42 assert(!isa<ScalableVectorType>(LHS->getType()) && 43 "LHS Assumed to be fixed width"); 44 RHS = B.CreateVectorSplat( 45 cast<VectorType>(LHS->getType())->getElementCount(), RHS, 46 "scalar.splat"); 47 } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) { 48 assert(!isa<ScalableVectorType>(RHS->getType()) && 49 "RHS Assumed to be fixed width"); 50 LHS = B.CreateVectorSplat( 51 cast<VectorType>(RHS->getType())->getElementCount(), LHS, 52 "scalar.splat"); 53 } 54 return {LHS, RHS}; 55 } 56 57 public: MatrixBuilder(IRBuilderTy & Builder)58 MatrixBuilder(IRBuilderTy &Builder) : B(Builder) {} 59 60 /// Create a column major, strided matrix load. 61 /// \p DataPtr - Start address of the matrix read 62 /// \p Rows - Number of rows in matrix (must be a constant) 63 /// \p Columns - Number of columns in matrix (must be a constant) 64 /// \p Stride - Space between columns 65 CallInst *CreateColumnMajorLoad(Value *DataPtr, Align Alignment, 66 Value *Stride, bool IsVolatile, unsigned Rows, 67 unsigned Columns, const Twine &Name = "") { 68 69 // Deal with the pointer 70 PointerType *PtrTy = cast<PointerType>(DataPtr->getType()); 71 Type *EltTy = PtrTy->getElementType(); 72 73 auto *RetType = FixedVectorType::get(EltTy, Rows * Columns); 74 75 Value *Ops[] = {DataPtr, Stride, B.getInt1(IsVolatile), B.getInt32(Rows), 76 B.getInt32(Columns)}; 77 Type *OverloadedTypes[] = {RetType}; 78 79 Function *TheFn = Intrinsic::getDeclaration( 80 getModule(), Intrinsic::matrix_column_major_load, OverloadedTypes); 81 82 CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); 83 Attribute AlignAttr = 84 Attribute::getWithAlignment(Call->getContext(), Alignment); 85 Call->addAttribute(1, AlignAttr); 86 return Call; 87 } 88 89 /// Create a column major, strided matrix store. 90 /// \p Matrix - Matrix to store 91 /// \p Ptr - Pointer to write back to 92 /// \p Stride - Space between columns 93 CallInst *CreateColumnMajorStore(Value *Matrix, Value *Ptr, Align Alignment, 94 Value *Stride, bool IsVolatile, 95 unsigned Rows, unsigned Columns, 96 const Twine &Name = "") { 97 Value *Ops[] = {Matrix, Ptr, 98 Stride, B.getInt1(IsVolatile), 99 B.getInt32(Rows), B.getInt32(Columns)}; 100 Type *OverloadedTypes[] = {Matrix->getType()}; 101 102 Function *TheFn = Intrinsic::getDeclaration( 103 getModule(), Intrinsic::matrix_column_major_store, OverloadedTypes); 104 105 CallInst *Call = B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); 106 Attribute AlignAttr = 107 Attribute::getWithAlignment(Call->getContext(), Alignment); 108 Call->addAttribute(2, AlignAttr); 109 return Call; 110 } 111 112 /// Create a llvm.matrix.transpose call, transposing \p Matrix with \p Rows 113 /// rows and \p Columns columns. 114 CallInst *CreateMatrixTranspose(Value *Matrix, unsigned Rows, 115 unsigned Columns, const Twine &Name = "") { 116 auto *OpType = cast<VectorType>(Matrix->getType()); 117 auto *ReturnType = 118 FixedVectorType::get(OpType->getElementType(), Rows * Columns); 119 120 Type *OverloadedTypes[] = {ReturnType}; 121 Value *Ops[] = {Matrix, B.getInt32(Rows), B.getInt32(Columns)}; 122 Function *TheFn = Intrinsic::getDeclaration( 123 getModule(), Intrinsic::matrix_transpose, OverloadedTypes); 124 125 return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); 126 } 127 128 /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p 129 /// RHS. 130 CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows, 131 unsigned LHSColumns, unsigned RHSColumns, 132 const Twine &Name = "") { 133 auto *LHSType = cast<VectorType>(LHS->getType()); 134 auto *RHSType = cast<VectorType>(RHS->getType()); 135 136 auto *ReturnType = 137 FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns); 138 139 Value *Ops[] = {LHS, RHS, B.getInt32(LHSRows), B.getInt32(LHSColumns), 140 B.getInt32(RHSColumns)}; 141 Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType}; 142 143 Function *TheFn = Intrinsic::getDeclaration( 144 getModule(), Intrinsic::matrix_multiply, OverloadedTypes); 145 return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name); 146 } 147 148 /// Insert a single element \p NewVal into \p Matrix at indices (\p RowIdx, \p 149 /// ColumnIdx). CreateMatrixInsert(Value * Matrix,Value * NewVal,Value * RowIdx,Value * ColumnIdx,unsigned NumRows)150 Value *CreateMatrixInsert(Value *Matrix, Value *NewVal, Value *RowIdx, 151 Value *ColumnIdx, unsigned NumRows) { 152 return B.CreateInsertElement( 153 Matrix, NewVal, 154 B.CreateAdd(B.CreateMul(ColumnIdx, ConstantInt::get( 155 ColumnIdx->getType(), NumRows)), 156 RowIdx)); 157 } 158 159 /// Add matrixes \p LHS and \p RHS. Support both integer and floating point 160 /// matrixes. CreateAdd(Value * LHS,Value * RHS)161 Value *CreateAdd(Value *LHS, Value *RHS) { 162 assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()); 163 if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) { 164 assert(!isa<ScalableVectorType>(LHS->getType()) && 165 "LHS Assumed to be fixed width"); 166 RHS = B.CreateVectorSplat( 167 cast<VectorType>(LHS->getType())->getElementCount(), RHS, 168 "scalar.splat"); 169 } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) { 170 assert(!isa<ScalableVectorType>(RHS->getType()) && 171 "RHS Assumed to be fixed width"); 172 LHS = B.CreateVectorSplat( 173 cast<VectorType>(RHS->getType())->getElementCount(), LHS, 174 "scalar.splat"); 175 } 176 177 return cast<VectorType>(LHS->getType()) 178 ->getElementType() 179 ->isFloatingPointTy() 180 ? B.CreateFAdd(LHS, RHS) 181 : B.CreateAdd(LHS, RHS); 182 } 183 184 /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating 185 /// point matrixes. CreateSub(Value * LHS,Value * RHS)186 Value *CreateSub(Value *LHS, Value *RHS) { 187 assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()); 188 if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) { 189 assert(!isa<ScalableVectorType>(LHS->getType()) && 190 "LHS Assumed to be fixed width"); 191 RHS = B.CreateVectorSplat( 192 cast<VectorType>(LHS->getType())->getElementCount(), RHS, 193 "scalar.splat"); 194 } else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) { 195 assert(!isa<ScalableVectorType>(RHS->getType()) && 196 "RHS Assumed to be fixed width"); 197 LHS = B.CreateVectorSplat( 198 cast<VectorType>(RHS->getType())->getElementCount(), LHS, 199 "scalar.splat"); 200 } 201 202 return cast<VectorType>(LHS->getType()) 203 ->getElementType() 204 ->isFloatingPointTy() 205 ? B.CreateFSub(LHS, RHS) 206 : B.CreateSub(LHS, RHS); 207 } 208 209 /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p 210 /// RHS. CreateScalarMultiply(Value * LHS,Value * RHS)211 Value *CreateScalarMultiply(Value *LHS, Value *RHS) { 212 std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS); 213 if (LHS->getType()->getScalarType()->isFloatingPointTy()) 214 return B.CreateFMul(LHS, RHS); 215 return B.CreateMul(LHS, RHS); 216 } 217 218 /// Extracts the element at (\p RowIdx, \p ColumnIdx) from \p Matrix. 219 Value *CreateExtractElement(Value *Matrix, Value *RowIdx, Value *ColumnIdx, 220 unsigned NumRows, Twine const &Name = "") { 221 222 unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(), 223 ColumnIdx->getType()->getScalarSizeInBits()); 224 Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth); 225 RowIdx = B.CreateZExt(RowIdx, IntTy); 226 ColumnIdx = B.CreateZExt(ColumnIdx, IntTy); 227 Value *NumRowsV = B.getIntN(MaxWidth, NumRows); 228 return B.CreateExtractElement( 229 Matrix, B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx), 230 "matext"); 231 } 232 }; 233 234 } // end namespace llvm 235 236 #endif // LLVM_IR_MATRIXBUILDER_H 237