1 //===------- VectorToSPIRV.cpp - Vector to SPIRV lowering passes ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to generate SPIRV operations for Vector
10 // operations.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "../PassDetail.h"
15 #include "mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRV.h"
16 #include "mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h"
17 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
18 #include "mlir/Dialect/SPIRV/SPIRVLowering.h"
19 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
20 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
21 #include "mlir/Dialect/Vector/VectorOps.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Transforms/DialectConversion.h"
24
25 using namespace mlir;
26
27 namespace {
28 struct VectorBroadcastConvert final
29 : public SPIRVOpLowering<vector::BroadcastOp> {
30 using SPIRVOpLowering<vector::BroadcastOp>::SPIRVOpLowering;
31 LogicalResult
matchAndRewrite__anonb8f689180111::VectorBroadcastConvert32 matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef<Value> operands,
33 ConversionPatternRewriter &rewriter) const override {
34 if (broadcastOp.source().getType().isa<VectorType>() ||
35 !spirv::CompositeType::isValid(broadcastOp.getVectorType()))
36 return failure();
37 vector::BroadcastOp::Adaptor adaptor(operands);
38 SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
39 adaptor.source());
40 Value construct = rewriter.create<spirv::CompositeConstructOp>(
41 broadcastOp.getLoc(), broadcastOp.getVectorType(), source);
42 rewriter.replaceOp(broadcastOp, construct);
43 return success();
44 }
45 };
46
47 struct VectorExtractOpConvert final
48 : public SPIRVOpLowering<vector::ExtractOp> {
49 using SPIRVOpLowering<vector::ExtractOp>::SPIRVOpLowering;
50 LogicalResult
matchAndRewrite__anonb8f689180111::VectorExtractOpConvert51 matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
52 ConversionPatternRewriter &rewriter) const override {
53 if (extractOp.getType().isa<VectorType>() ||
54 !spirv::CompositeType::isValid(extractOp.getVectorType()))
55 return failure();
56 vector::ExtractOp::Adaptor adaptor(operands);
57 int32_t id = extractOp.position().begin()->cast<IntegerAttr>().getInt();
58 Value newExtract = rewriter.create<spirv::CompositeExtractOp>(
59 extractOp.getLoc(), adaptor.vector(), id);
60 rewriter.replaceOp(extractOp, newExtract);
61 return success();
62 }
63 };
64
65 struct VectorInsertOpConvert final : public SPIRVOpLowering<vector::InsertOp> {
66 using SPIRVOpLowering<vector::InsertOp>::SPIRVOpLowering;
67 LogicalResult
matchAndRewrite__anonb8f689180111::VectorInsertOpConvert68 matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
69 ConversionPatternRewriter &rewriter) const override {
70 if (insertOp.getSourceType().isa<VectorType>() ||
71 !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
72 return failure();
73 vector::InsertOp::Adaptor adaptor(operands);
74 int32_t id = insertOp.position().begin()->cast<IntegerAttr>().getInt();
75 Value newInsert = rewriter.create<spirv::CompositeInsertOp>(
76 insertOp.getLoc(), adaptor.source(), adaptor.dest(), id);
77 rewriter.replaceOp(insertOp, newInsert);
78 return success();
79 }
80 };
81
82 struct VectorExtractElementOpConvert final
83 : public SPIRVOpLowering<vector::ExtractElementOp> {
84 using SPIRVOpLowering<vector::ExtractElementOp>::SPIRVOpLowering;
85 LogicalResult
matchAndRewrite__anonb8f689180111::VectorExtractElementOpConvert86 matchAndRewrite(vector::ExtractElementOp extractElementOp,
87 ArrayRef<Value> operands,
88 ConversionPatternRewriter &rewriter) const override {
89 if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
90 return failure();
91 vector::ExtractElementOp::Adaptor adaptor(operands);
92 Value newExtractElement = rewriter.create<spirv::VectorExtractDynamicOp>(
93 extractElementOp.getLoc(), extractElementOp.getType(), adaptor.vector(),
94 extractElementOp.position());
95 rewriter.replaceOp(extractElementOp, newExtractElement);
96 return success();
97 }
98 };
99
100 struct VectorInsertElementOpConvert final
101 : public SPIRVOpLowering<vector::InsertElementOp> {
102 using SPIRVOpLowering<vector::InsertElementOp>::SPIRVOpLowering;
103 LogicalResult
matchAndRewrite__anonb8f689180111::VectorInsertElementOpConvert104 matchAndRewrite(vector::InsertElementOp insertElementOp,
105 ArrayRef<Value> operands,
106 ConversionPatternRewriter &rewriter) const override {
107 if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
108 return failure();
109 vector::InsertElementOp::Adaptor adaptor(operands);
110 Value newInsertElement = rewriter.create<spirv::VectorInsertDynamicOp>(
111 insertElementOp.getLoc(), insertElementOp.getType(),
112 insertElementOp.dest(), adaptor.source(), insertElementOp.position());
113 rewriter.replaceOp(insertElementOp, newInsertElement);
114 return success();
115 }
116 };
117
118 } // namespace
119
populateVectorToSPIRVPatterns(MLIRContext * context,SPIRVTypeConverter & typeConverter,OwningRewritePatternList & patterns)120 void mlir::populateVectorToSPIRVPatterns(MLIRContext *context,
121 SPIRVTypeConverter &typeConverter,
122 OwningRewritePatternList &patterns) {
123 patterns.insert<VectorBroadcastConvert, VectorExtractOpConvert,
124 VectorInsertOpConvert, VectorExtractElementOpConvert,
125 VectorInsertElementOpConvert>(context, typeConverter);
126 }
127
128 namespace {
129 struct LowerVectorToSPIRVPass
130 : public ConvertVectorToSPIRVBase<LowerVectorToSPIRVPass> {
131 void runOnOperation() override;
132 };
133 } // namespace
134
runOnOperation()135 void LowerVectorToSPIRVPass::runOnOperation() {
136 MLIRContext *context = &getContext();
137 ModuleOp module = getOperation();
138
139 auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
140 std::unique_ptr<ConversionTarget> target =
141 spirv::SPIRVConversionTarget::get(targetAttr);
142
143 SPIRVTypeConverter typeConverter(targetAttr);
144 OwningRewritePatternList patterns;
145 populateVectorToSPIRVPatterns(context, typeConverter, patterns);
146
147 target->addLegalOp<ModuleOp, ModuleTerminatorOp>();
148 target->addLegalOp<FuncOp>();
149
150 if (failed(applyFullConversion(module, *target, std::move(patterns))))
151 return signalPassFailure();
152 }
153
154 std::unique_ptr<OperationPass<ModuleOp>>
createConvertVectorToSPIRVPass()155 mlir::createConvertVectorToSPIRVPass() {
156 return std::make_unique<LowerVectorToSPIRVPass>();
157 }
158