• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- ConvertAVX512ToLLVM.cpp - Convert AVX512 to the LLVM dialect -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
10 
11 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
12 #include "mlir/Dialect/AVX512/AVX512Dialect.h"
13 #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/Vector/VectorOps.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/PatternMatch.h"
19 
20 using namespace mlir;
21 using namespace mlir::vector;
22 using namespace mlir::avx512;
23 
24 template <typename OpTy>
getSrcVectorElementType(OpTy op)25 static Type getSrcVectorElementType(OpTy op) {
26   return op.src().getType().template cast<VectorType>().getElementType();
27 }
28 
29 // TODO: Code is currently copy-pasted and adapted from the code
30 // 1-1 LLVM conversion. It would better if it were properly exposed in core and
31 // reusable.
32 /// Basic lowering implementation for one-to-one rewriting from AVX512 Ops to
33 /// LLVM Dialect Ops. Convert the type of the result to an LLVM type, pass
34 /// operands as is, preserve attributes.
35 template <typename SourceOp, typename TargetOp>
36 static LogicalResult
matchAndRewriteOneToOne(const ConvertToLLVMPattern & lowering,LLVMTypeConverter & typeConverter,Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter)37 matchAndRewriteOneToOne(const ConvertToLLVMPattern &lowering,
38                         LLVMTypeConverter &typeConverter, Operation *op,
39                         ArrayRef<Value> operands,
40                         ConversionPatternRewriter &rewriter) {
41   unsigned numResults = op->getNumResults();
42 
43   Type packedType;
44   if (numResults != 0) {
45     packedType = typeConverter.packFunctionResults(op->getResultTypes());
46     if (!packedType)
47       return failure();
48   }
49 
50   auto newOp = rewriter.create<TargetOp>(op->getLoc(), packedType, operands,
51                                          op->getAttrs());
52 
53   // If the operation produced 0 or 1 result, return them immediately.
54   if (numResults == 0)
55     return rewriter.eraseOp(op), success();
56   if (numResults == 1)
57     return rewriter.replaceOp(op, newOp->getResult(0)), success();
58 
59   // Otherwise, it had been converted to an operation producing a structure.
60   // Extract individual results from the structure and return them as list.
61   SmallVector<Value, 4> results;
62   results.reserve(numResults);
63   for (unsigned i = 0; i < numResults; ++i) {
64     auto type = typeConverter.convertType(op->getResult(i).getType());
65     results.push_back(rewriter.create<LLVM::ExtractValueOp>(
66         op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i)));
67   }
68   rewriter.replaceOp(op, results);
69   return success();
70 }
71 
72 namespace {
73 // TODO: Patterns are too verbose due to the fact that we have 1 op (e.g.
74 // MaskRndScaleOp) and different possible target ops. It would be better to take
75 // a Functor so that all these conversions become 1-liners.
76 struct MaskRndScaleOpPS512Conversion : public ConvertToLLVMPattern {
MaskRndScaleOpPS512Conversion__anon14c4020f0111::MaskRndScaleOpPS512Conversion77   explicit MaskRndScaleOpPS512Conversion(MLIRContext *context,
78                                          LLVMTypeConverter &typeConverter)
79       : ConvertToLLVMPattern(MaskRndScaleOp::getOperationName(), context,
80                              typeConverter) {}
81 
82   LogicalResult
matchAndRewrite__anon14c4020f0111::MaskRndScaleOpPS512Conversion83   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
84                   ConversionPatternRewriter &rewriter) const override {
85     if (!getSrcVectorElementType(cast<MaskRndScaleOp>(op)).isF32())
86       return failure();
87     return matchAndRewriteOneToOne<MaskRndScaleOp,
88                                    LLVM::x86_avx512_mask_rndscale_ps_512>(
89         *this, *getTypeConverter(), op, operands, rewriter);
90   }
91 };
92 
93 struct MaskRndScaleOpPD512Conversion : public ConvertToLLVMPattern {
MaskRndScaleOpPD512Conversion__anon14c4020f0111::MaskRndScaleOpPD512Conversion94   explicit MaskRndScaleOpPD512Conversion(MLIRContext *context,
95                                          LLVMTypeConverter &typeConverter)
96       : ConvertToLLVMPattern(MaskRndScaleOp::getOperationName(), context,
97                              typeConverter) {}
98 
99   LogicalResult
matchAndRewrite__anon14c4020f0111::MaskRndScaleOpPD512Conversion100   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
101                   ConversionPatternRewriter &rewriter) const override {
102     if (!getSrcVectorElementType(cast<MaskRndScaleOp>(op)).isF64())
103       return failure();
104     return matchAndRewriteOneToOne<MaskRndScaleOp,
105                                    LLVM::x86_avx512_mask_rndscale_pd_512>(
106         *this, *getTypeConverter(), op, operands, rewriter);
107   }
108 };
109 
110 struct ScaleFOpPS512Conversion : public ConvertToLLVMPattern {
ScaleFOpPS512Conversion__anon14c4020f0111::ScaleFOpPS512Conversion111   explicit ScaleFOpPS512Conversion(MLIRContext *context,
112                                    LLVMTypeConverter &typeConverter)
113       : ConvertToLLVMPattern(MaskScaleFOp::getOperationName(), context,
114                              typeConverter) {}
115 
116   LogicalResult
matchAndRewrite__anon14c4020f0111::ScaleFOpPS512Conversion117   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
118                   ConversionPatternRewriter &rewriter) const override {
119     if (!getSrcVectorElementType(cast<MaskScaleFOp>(op)).isF32())
120       return failure();
121     return matchAndRewriteOneToOne<MaskScaleFOp,
122                                    LLVM::x86_avx512_mask_scalef_ps_512>(
123         *this, *getTypeConverter(), op, operands, rewriter);
124   }
125 };
126 
127 struct ScaleFOpPD512Conversion : public ConvertToLLVMPattern {
ScaleFOpPD512Conversion__anon14c4020f0111::ScaleFOpPD512Conversion128   explicit ScaleFOpPD512Conversion(MLIRContext *context,
129                                    LLVMTypeConverter &typeConverter)
130       : ConvertToLLVMPattern(MaskScaleFOp::getOperationName(), context,
131                              typeConverter) {}
132 
133   LogicalResult
matchAndRewrite__anon14c4020f0111::ScaleFOpPD512Conversion134   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
135                   ConversionPatternRewriter &rewriter) const override {
136     if (!getSrcVectorElementType(cast<MaskScaleFOp>(op)).isF64())
137       return failure();
138     return matchAndRewriteOneToOne<MaskScaleFOp,
139                                    LLVM::x86_avx512_mask_scalef_pd_512>(
140         *this, *getTypeConverter(), op, operands, rewriter);
141   }
142 };
143 } // namespace
144 
145 /// Populate the given list with patterns that convert from AVX512 to LLVM.
populateAVX512ToLLVMConversionPatterns(LLVMTypeConverter & converter,OwningRewritePatternList & patterns)146 void mlir::populateAVX512ToLLVMConversionPatterns(
147     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
148   MLIRContext *ctx = converter.getDialect()->getContext();
149   // clang-format off
150   patterns.insert<MaskRndScaleOpPS512Conversion,
151                   MaskRndScaleOpPD512Conversion,
152                   ScaleFOpPS512Conversion,
153                   ScaleFOpPD512Conversion>(ctx, converter);
154   // clang-format on
155 }
156