• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===------- VectorToSPIRV.cpp - Vector to SPIRV lowering passes ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to generate SPIRV operations for Vector
10 // operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "../PassDetail.h"
15 #include "mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRV.h"
16 #include "mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h"
17 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
18 #include "mlir/Dialect/SPIRV/SPIRVLowering.h"
19 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
20 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
21 #include "mlir/Dialect/Vector/VectorOps.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 
25 using namespace mlir;
26 
27 namespace {
28 struct VectorBroadcastConvert final
29     : public SPIRVOpLowering<vector::BroadcastOp> {
30   using SPIRVOpLowering<vector::BroadcastOp>::SPIRVOpLowering;
31   LogicalResult
matchAndRewrite__anonb8f689180111::VectorBroadcastConvert32   matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef<Value> operands,
33                   ConversionPatternRewriter &rewriter) const override {
34     if (broadcastOp.source().getType().isa<VectorType>() ||
35         !spirv::CompositeType::isValid(broadcastOp.getVectorType()))
36       return failure();
37     vector::BroadcastOp::Adaptor adaptor(operands);
38     SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
39                                  adaptor.source());
40     Value construct = rewriter.create<spirv::CompositeConstructOp>(
41         broadcastOp.getLoc(), broadcastOp.getVectorType(), source);
42     rewriter.replaceOp(broadcastOp, construct);
43     return success();
44   }
45 };
46 
47 struct VectorExtractOpConvert final
48     : public SPIRVOpLowering<vector::ExtractOp> {
49   using SPIRVOpLowering<vector::ExtractOp>::SPIRVOpLowering;
50   LogicalResult
matchAndRewrite__anonb8f689180111::VectorExtractOpConvert51   matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
52                   ConversionPatternRewriter &rewriter) const override {
53     if (extractOp.getType().isa<VectorType>() ||
54         !spirv::CompositeType::isValid(extractOp.getVectorType()))
55       return failure();
56     vector::ExtractOp::Adaptor adaptor(operands);
57     int32_t id = extractOp.position().begin()->cast<IntegerAttr>().getInt();
58     Value newExtract = rewriter.create<spirv::CompositeExtractOp>(
59         extractOp.getLoc(), adaptor.vector(), id);
60     rewriter.replaceOp(extractOp, newExtract);
61     return success();
62   }
63 };
64 
65 struct VectorInsertOpConvert final : public SPIRVOpLowering<vector::InsertOp> {
66   using SPIRVOpLowering<vector::InsertOp>::SPIRVOpLowering;
67   LogicalResult
matchAndRewrite__anonb8f689180111::VectorInsertOpConvert68   matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
69                   ConversionPatternRewriter &rewriter) const override {
70     if (insertOp.getSourceType().isa<VectorType>() ||
71         !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
72       return failure();
73     vector::InsertOp::Adaptor adaptor(operands);
74     int32_t id = insertOp.position().begin()->cast<IntegerAttr>().getInt();
75     Value newInsert = rewriter.create<spirv::CompositeInsertOp>(
76         insertOp.getLoc(), adaptor.source(), adaptor.dest(), id);
77     rewriter.replaceOp(insertOp, newInsert);
78     return success();
79   }
80 };
81 
82 struct VectorExtractElementOpConvert final
83     : public SPIRVOpLowering<vector::ExtractElementOp> {
84   using SPIRVOpLowering<vector::ExtractElementOp>::SPIRVOpLowering;
85   LogicalResult
matchAndRewrite__anonb8f689180111::VectorExtractElementOpConvert86   matchAndRewrite(vector::ExtractElementOp extractElementOp,
87                   ArrayRef<Value> operands,
88                   ConversionPatternRewriter &rewriter) const override {
89     if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
90       return failure();
91     vector::ExtractElementOp::Adaptor adaptor(operands);
92     Value newExtractElement = rewriter.create<spirv::VectorExtractDynamicOp>(
93         extractElementOp.getLoc(), extractElementOp.getType(), adaptor.vector(),
94         extractElementOp.position());
95     rewriter.replaceOp(extractElementOp, newExtractElement);
96     return success();
97   }
98 };
99 
100 struct VectorInsertElementOpConvert final
101     : public SPIRVOpLowering<vector::InsertElementOp> {
102   using SPIRVOpLowering<vector::InsertElementOp>::SPIRVOpLowering;
103   LogicalResult
matchAndRewrite__anonb8f689180111::VectorInsertElementOpConvert104   matchAndRewrite(vector::InsertElementOp insertElementOp,
105                   ArrayRef<Value> operands,
106                   ConversionPatternRewriter &rewriter) const override {
107     if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
108       return failure();
109     vector::InsertElementOp::Adaptor adaptor(operands);
110     Value newInsertElement = rewriter.create<spirv::VectorInsertDynamicOp>(
111         insertElementOp.getLoc(), insertElementOp.getType(),
112         insertElementOp.dest(), adaptor.source(), insertElementOp.position());
113     rewriter.replaceOp(insertElementOp, newInsertElement);
114     return success();
115   }
116 };
117 
118 } // namespace
119 
populateVectorToSPIRVPatterns(MLIRContext * context,SPIRVTypeConverter & typeConverter,OwningRewritePatternList & patterns)120 void mlir::populateVectorToSPIRVPatterns(MLIRContext *context,
121                                          SPIRVTypeConverter &typeConverter,
122                                          OwningRewritePatternList &patterns) {
123   patterns.insert<VectorBroadcastConvert, VectorExtractOpConvert,
124                   VectorInsertOpConvert, VectorExtractElementOpConvert,
125                   VectorInsertElementOpConvert>(context, typeConverter);
126 }
127 
128 namespace {
129 struct LowerVectorToSPIRVPass
130     : public ConvertVectorToSPIRVBase<LowerVectorToSPIRVPass> {
131   void runOnOperation() override;
132 };
133 } // namespace
134 
runOnOperation()135 void LowerVectorToSPIRVPass::runOnOperation() {
136   MLIRContext *context = &getContext();
137   ModuleOp module = getOperation();
138 
139   auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
140   std::unique_ptr<ConversionTarget> target =
141       spirv::SPIRVConversionTarget::get(targetAttr);
142 
143   SPIRVTypeConverter typeConverter(targetAttr);
144   OwningRewritePatternList patterns;
145   populateVectorToSPIRVPatterns(context, typeConverter, patterns);
146 
147   target->addLegalOp<ModuleOp, ModuleTerminatorOp>();
148   target->addLegalOp<FuncOp>();
149 
150   if (failed(applyFullConversion(module, *target, std::move(patterns))))
151     return signalPassFailure();
152 }
153 
154 std::unique_ptr<OperationPass<ModuleOp>>
createConvertVectorToSPIRVPass()155 mlir::createConvertVectorToSPIRVPass() {
156   return std::make_unique<LowerVectorToSPIRVPass>();
157 }
158