1 //===- ConvertAVX512ToLLVM.cpp - Convert AVX512 to the LLVM dialect -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
10
11 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
12 #include "mlir/Dialect/AVX512/AVX512Dialect.h"
13 #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Vector/VectorOps.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/PatternMatch.h"
19
20 using namespace mlir;
21 using namespace mlir::vector;
22 using namespace mlir::avx512;
23
24 template <typename OpTy>
getSrcVectorElementType(OpTy op)25 static Type getSrcVectorElementType(OpTy op) {
26 return op.src().getType().template cast<VectorType>().getElementType();
27 }
28
// TODO: Code is currently copy-pasted and adapted from the code for the
// 1-1 LLVM conversion. It would be better if it were properly exposed in core
// and reusable.
32 /// Basic lowering implementation for one-to-one rewriting from AVX512 Ops to
33 /// LLVM Dialect Ops. Convert the type of the result to an LLVM type, pass
34 /// operands as is, preserve attributes.
35 template <typename SourceOp, typename TargetOp>
36 static LogicalResult
matchAndRewriteOneToOne(const ConvertToLLVMPattern & lowering,LLVMTypeConverter & typeConverter,Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)37 matchAndRewriteOneToOne(const ConvertToLLVMPattern &lowering,
38 LLVMTypeConverter &typeConverter, Operation *op,
39 ArrayRef<Value> operands,
40 ConversionPatternRewriter &rewriter) {
41 unsigned numResults = op->getNumResults();
42
43 Type packedType;
44 if (numResults != 0) {
45 packedType = typeConverter.packFunctionResults(op->getResultTypes());
46 if (!packedType)
47 return failure();
48 }
49
50 auto newOp = rewriter.create<TargetOp>(op->getLoc(), packedType, operands,
51 op->getAttrs());
52
53 // If the operation produced 0 or 1 result, return them immediately.
54 if (numResults == 0)
55 return rewriter.eraseOp(op), success();
56 if (numResults == 1)
57 return rewriter.replaceOp(op, newOp->getResult(0)), success();
58
59 // Otherwise, it had been converted to an operation producing a structure.
60 // Extract individual results from the structure and return them as list.
61 SmallVector<Value, 4> results;
62 results.reserve(numResults);
63 for (unsigned i = 0; i < numResults; ++i) {
64 auto type = typeConverter.convertType(op->getResult(i).getType());
65 results.push_back(rewriter.create<LLVM::ExtractValueOp>(
66 op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i)));
67 }
68 rewriter.replaceOp(op, results);
69 return success();
70 }
71
72 namespace {
73 // TODO: Patterns are too verbose due to the fact that we have 1 op (e.g.
74 // MaskRndScaleOp) and different possible target ops. It would be better to take
75 // a Functor so that all these conversions become 1-liners.
76 struct MaskRndScaleOpPS512Conversion : public ConvertToLLVMPattern {
MaskRndScaleOpPS512Conversion__anon14c4020f0111::MaskRndScaleOpPS512Conversion77 explicit MaskRndScaleOpPS512Conversion(MLIRContext *context,
78 LLVMTypeConverter &typeConverter)
79 : ConvertToLLVMPattern(MaskRndScaleOp::getOperationName(), context,
80 typeConverter) {}
81
82 LogicalResult
matchAndRewrite__anon14c4020f0111::MaskRndScaleOpPS512Conversion83 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
84 ConversionPatternRewriter &rewriter) const override {
85 if (!getSrcVectorElementType(cast<MaskRndScaleOp>(op)).isF32())
86 return failure();
87 return matchAndRewriteOneToOne<MaskRndScaleOp,
88 LLVM::x86_avx512_mask_rndscale_ps_512>(
89 *this, *getTypeConverter(), op, operands, rewriter);
90 }
91 };
92
93 struct MaskRndScaleOpPD512Conversion : public ConvertToLLVMPattern {
MaskRndScaleOpPD512Conversion__anon14c4020f0111::MaskRndScaleOpPD512Conversion94 explicit MaskRndScaleOpPD512Conversion(MLIRContext *context,
95 LLVMTypeConverter &typeConverter)
96 : ConvertToLLVMPattern(MaskRndScaleOp::getOperationName(), context,
97 typeConverter) {}
98
99 LogicalResult
matchAndRewrite__anon14c4020f0111::MaskRndScaleOpPD512Conversion100 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
101 ConversionPatternRewriter &rewriter) const override {
102 if (!getSrcVectorElementType(cast<MaskRndScaleOp>(op)).isF64())
103 return failure();
104 return matchAndRewriteOneToOne<MaskRndScaleOp,
105 LLVM::x86_avx512_mask_rndscale_pd_512>(
106 *this, *getTypeConverter(), op, operands, rewriter);
107 }
108 };
109
110 struct ScaleFOpPS512Conversion : public ConvertToLLVMPattern {
ScaleFOpPS512Conversion__anon14c4020f0111::ScaleFOpPS512Conversion111 explicit ScaleFOpPS512Conversion(MLIRContext *context,
112 LLVMTypeConverter &typeConverter)
113 : ConvertToLLVMPattern(MaskScaleFOp::getOperationName(), context,
114 typeConverter) {}
115
116 LogicalResult
matchAndRewrite__anon14c4020f0111::ScaleFOpPS512Conversion117 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
118 ConversionPatternRewriter &rewriter) const override {
119 if (!getSrcVectorElementType(cast<MaskScaleFOp>(op)).isF32())
120 return failure();
121 return matchAndRewriteOneToOne<MaskScaleFOp,
122 LLVM::x86_avx512_mask_scalef_ps_512>(
123 *this, *getTypeConverter(), op, operands, rewriter);
124 }
125 };
126
127 struct ScaleFOpPD512Conversion : public ConvertToLLVMPattern {
ScaleFOpPD512Conversion__anon14c4020f0111::ScaleFOpPD512Conversion128 explicit ScaleFOpPD512Conversion(MLIRContext *context,
129 LLVMTypeConverter &typeConverter)
130 : ConvertToLLVMPattern(MaskScaleFOp::getOperationName(), context,
131 typeConverter) {}
132
133 LogicalResult
matchAndRewrite__anon14c4020f0111::ScaleFOpPD512Conversion134 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
135 ConversionPatternRewriter &rewriter) const override {
136 if (!getSrcVectorElementType(cast<MaskScaleFOp>(op)).isF64())
137 return failure();
138 return matchAndRewriteOneToOne<MaskScaleFOp,
139 LLVM::x86_avx512_mask_scalef_pd_512>(
140 *this, *getTypeConverter(), op, operands, rewriter);
141 }
142 };
143 } // namespace
144
145 /// Populate the given list with patterns that convert from AVX512 to LLVM.
populateAVX512ToLLVMConversionPatterns(LLVMTypeConverter & converter,OwningRewritePatternList & patterns)146 void mlir::populateAVX512ToLLVMConversionPatterns(
147 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
148 MLIRContext *ctx = converter.getDialect()->getContext();
149 // clang-format off
150 patterns.insert<MaskRndScaleOpPS512Conversion,
151 MaskRndScaleOpPD512Conversion,
152 ScaleFOpPS512Conversion,
153 ScaleFOpPD512Conversion>(ctx, converter);
154 // clang-format on
155 }
156