1 //===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 #ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ 9 #define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ 10 11 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 12 #include "mlir/Dialect/GPU/GPUDialect.h" 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/IR/Builders.h" 16 17 namespace mlir { 18 19 /// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func` 20 /// depending on the element type that Op operates upon. The function 21 /// declaration is added in case it was not added before. 22 /// 23 /// If the input values are of f16 type, the value is first casted to f32, the 24 /// function called and then the result casted back. 25 /// 26 /// Example with NVVM: 27 /// %exp_f32 = std.exp %arg_f32 : f32 28 /// 29 /// will be transformed into 30 /// llvm.call @__nv_expf(%arg_f32) : (!llvm.float) -> !llvm.float 31 template <typename SourceOp> 32 struct OpToFuncCallLowering : public ConvertToLLVMPattern { 33 public: OpToFuncCallLoweringOpToFuncCallLowering34 explicit OpToFuncCallLowering(LLVMTypeConverter &lowering_, StringRef f32Func, 35 StringRef f64Func) 36 : ConvertToLLVMPattern(SourceOp::getOperationName(), 37 lowering_.getDialect()->getContext(), lowering_), 38 f32Func(f32Func), f64Func(f64Func) {} 39 40 LogicalResult matchAndRewriteOpToFuncCallLowering41 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 42 ConversionPatternRewriter &rewriter) const override { 43 using LLVM::LLVMFuncOp; 44 using LLVM::LLVMType; 45 46 static_assert( 47 std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value, 48 "expected single result op"); 49 50 static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>, 51 SourceOp>::value, 52 "expected op with same operand and result types"); 53 54 SmallVector<Value, 1> castedOperands; 55 for (Value operand : operands) 56 castedOperands.push_back(maybeCast(operand, rewriter)); 57 58 LLVMType resultType = 59 castedOperands.front().getType().cast<LLVM::LLVMType>(); 60 LLVMType funcType = getFunctionType(resultType, castedOperands); 61 StringRef funcName = getFunctionName(funcType.getFunctionResultType()); 62 if (funcName.empty()) 63 return failure(); 64 65 LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op); 66 auto callOp = rewriter.create<LLVM::CallOp>( 67 op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp), 68 castedOperands); 69 70 if (resultType == operands.front().getType()) { 71 rewriter.replaceOp(op, {callOp.getResult(0)}); 72 return success(); 73 } 74 75 Value truncated = rewriter.create<LLVM::FPTruncOp>( 76 op->getLoc(), operands.front().getType(), callOp.getResult(0)); 77 rewriter.replaceOp(op, {truncated}); 78 return success(); 79 } 80 81 private: maybeCastOpToFuncCallLowering82 Value maybeCast(Value operand, PatternRewriter &rewriter) const { 83 LLVM::LLVMType type = operand.getType().cast<LLVM::LLVMType>(); 84 if (!type.isHalfTy()) 85 return operand; 86 87 return rewriter.create<LLVM::FPExtOp>( 88 operand.getLoc(), LLVM::LLVMType::getFloatTy(rewriter.getContext()), 89 operand); 90 } 91 getFunctionTypeOpToFuncCallLowering92 LLVM::LLVMType getFunctionType(LLVM::LLVMType resultType, 93 ArrayRef<Value> operands) const { 94 using LLVM::LLVMType; 95 SmallVector<LLVMType, 1> operandTypes; 96 for (Value operand : operands) { 97 operandTypes.push_back(operand.getType().cast<LLVMType>()); 98 } 99 return LLVMType::getFunctionTy(resultType, operandTypes, 100 /*isVarArg=*/false); 101 } 102 getFunctionNameOpToFuncCallLowering103 StringRef getFunctionName(LLVM::LLVMType type) const { 104 if (type.isFloatTy()) 105 return f32Func; 106 if (type.isDoubleTy()) 107 return f64Func; 108 return ""; 109 } 110 appendOrGetFuncOpOpToFuncCallLowering111 LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, 112 LLVM::LLVMType funcType, 113 Operation *op) const { 114 using LLVM::LLVMFuncOp; 115 116 Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcName); 117 if (funcOp) 118 return cast<LLVMFuncOp>(*funcOp); 119 120 mlir::OpBuilder b(op->getParentOfType<LLVMFuncOp>()); 121 return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType); 122 } 123 124 const std::string f32Func; 125 const std::string f64Func; 126 }; 127 128 } // namespace mlir 129 130 #endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ 131