1 //===- ConvertSPIRVToLLVM.cpp - SPIR-V dialect to LLVM dialect conversion -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h"
14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
15 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/SPIRV/LayoutUtils.h"
18 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
19 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
20 #include "mlir/Dialect/StandardOps/IR/Ops.h"
21 #include "mlir/IR/BuiltinOps.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Support/LogicalResult.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27
28 #define DEBUG_TYPE "spirv-to-llvm-pattern"
29
30 using namespace mlir;
31
32 //===----------------------------------------------------------------------===//
33 // Utility functions
34 //===----------------------------------------------------------------------===//
35
36 /// Returns true if the given type is a signed integer or vector type.
isSignedIntegerOrVector(Type type)37 static bool isSignedIntegerOrVector(Type type) {
38 if (type.isSignedInteger())
39 return true;
40 if (auto vecType = type.dyn_cast<VectorType>())
41 return vecType.getElementType().isSignedInteger();
42 return false;
43 }
44
45 /// Returns true if the given type is an unsigned integer or vector type
isUnsignedIntegerOrVector(Type type)46 static bool isUnsignedIntegerOrVector(Type type) {
47 if (type.isUnsignedInteger())
48 return true;
49 if (auto vecType = type.dyn_cast<VectorType>())
50 return vecType.getElementType().isUnsignedInteger();
51 return false;
52 }
53
54 /// Returns the bit width of integer, float or vector of float or integer values
getBitWidth(Type type)55 static unsigned getBitWidth(Type type) {
56 assert((type.isIntOrFloat() || type.isa<VectorType>()) &&
57 "bitwidth is not supported for this type");
58 if (type.isIntOrFloat())
59 return type.getIntOrFloatBitWidth();
60 auto vecType = type.dyn_cast<VectorType>();
61 auto elementType = vecType.getElementType();
62 assert(elementType.isIntOrFloat() &&
63 "only integers and floats have a bitwidth");
64 return elementType.getIntOrFloatBitWidth();
65 }
66
67 /// Returns the bit width of LLVMType integer or vector.
getLLVMTypeBitWidth(LLVM::LLVMType type)68 static unsigned getLLVMTypeBitWidth(LLVM::LLVMType type) {
69 return type.isVectorTy() ? type.getVectorElementType().getIntegerBitWidth()
70 : type.getIntegerBitWidth();
71 }
72
73 /// Creates `IntegerAttribute` with all bits set for given type
minusOneIntegerAttribute(Type type,Builder builder)74 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
75 if (auto vecType = type.dyn_cast<VectorType>()) {
76 auto integerType = vecType.getElementType().cast<IntegerType>();
77 return builder.getIntegerAttr(integerType, -1);
78 }
79 auto integerType = type.cast<IntegerType>();
80 return builder.getIntegerAttr(integerType, -1);
81 }
82
83 /// Creates `llvm.mlir.constant` with all bits set for the given type.
createConstantAllBitsSet(Location loc,Type srcType,Type dstType,PatternRewriter & rewriter)84 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
85 PatternRewriter &rewriter) {
86 if (srcType.isa<VectorType>()) {
87 return rewriter.create<LLVM::ConstantOp>(
88 loc, dstType,
89 SplatElementsAttr::get(srcType.cast<ShapedType>(),
90 minusOneIntegerAttribute(srcType, rewriter)));
91 }
92 return rewriter.create<LLVM::ConstantOp>(
93 loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
94 }
95
96 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
createFPConstant(Location loc,Type srcType,Type dstType,PatternRewriter & rewriter,double value)97 static Value createFPConstant(Location loc, Type srcType, Type dstType,
98 PatternRewriter &rewriter, double value) {
99 if (auto vecType = srcType.dyn_cast<VectorType>()) {
100 auto floatType = vecType.getElementType().cast<FloatType>();
101 return rewriter.create<LLVM::ConstantOp>(
102 loc, dstType,
103 SplatElementsAttr::get(vecType,
104 rewriter.getFloatAttr(floatType, value)));
105 }
106 auto floatType = srcType.cast<FloatType>();
107 return rewriter.create<LLVM::ConstantOp>(
108 loc, dstType, rewriter.getFloatAttr(floatType, value));
109 }
110
111 /// Utility function for bitfield ops:
112 /// - `BitFieldInsert`
113 /// - `BitFieldSExtract`
114 /// - `BitFieldUExtract`
115 /// Truncates or extends the value. If the bitwidth of the value is the same as
116 /// `dstType` bitwidth, the value remains unchanged.
optionallyTruncateOrExtend(Location loc,Value value,Type dstType,PatternRewriter & rewriter)117 static Value optionallyTruncateOrExtend(Location loc, Value value, Type dstType,
118 PatternRewriter &rewriter) {
119 auto srcType = value.getType();
120 auto llvmType = dstType.cast<LLVM::LLVMType>();
121 unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
122 unsigned valueBitWidth =
123 srcType.isa<LLVM::LLVMType>()
124 ? getLLVMTypeBitWidth(srcType.cast<LLVM::LLVMType>())
125 : getBitWidth(srcType);
126
127 if (valueBitWidth < targetBitWidth)
128 return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
129 // If the bit widths of `Count` and `Offset` are greater than the bit width
130 // of the target type, they are truncated. Truncation is safe since `Count`
131 // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
132 // both values can be expressed in 8 bits.
133 if (valueBitWidth > targetBitWidth)
134 return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
135 return value;
136 }
137
138 /// Broadcasts the value to vector with `numElements` number of elements.
broadcast(Location loc,Value toBroadcast,unsigned numElements,LLVMTypeConverter & typeConverter,ConversionPatternRewriter & rewriter)139 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
140 LLVMTypeConverter &typeConverter,
141 ConversionPatternRewriter &rewriter) {
142 auto vectorType = VectorType::get(numElements, toBroadcast.getType());
143 auto llvmVectorType = typeConverter.convertType(vectorType);
144 auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
145 Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType);
146 for (unsigned i = 0; i < numElements; ++i) {
147 auto index = rewriter.create<LLVM::ConstantOp>(
148 loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
149 broadcasted = rewriter.create<LLVM::InsertElementOp>(
150 loc, llvmVectorType, broadcasted, toBroadcast, index);
151 }
152 return broadcasted;
153 }
154
155 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
optionallyBroadcast(Location loc,Value value,Type srcType,LLVMTypeConverter & typeConverter,ConversionPatternRewriter & rewriter)156 static Value optionallyBroadcast(Location loc, Value value, Type srcType,
157 LLVMTypeConverter &typeConverter,
158 ConversionPatternRewriter &rewriter) {
159 if (auto vectorType = srcType.dyn_cast<VectorType>()) {
160 unsigned numElements = vectorType.getNumElements();
161 return broadcast(loc, value, numElements, typeConverter, rewriter);
162 }
163 return value;
164 }
165
166 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
167 /// `BitFieldUExtract`.
168 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
169 /// a vector type, construct a vector that has:
170 /// - same number of elements as `Base`
171 /// - each element has the type that is the same as the type of `Offset` or
172 /// `Count`
173 /// - each element has the same value as `Offset` or `Count`
174 /// Then cast `Offset` and `Count` if their bit width is different
175 /// from `Base` bit width.
processCountOrOffset(Location loc,Value value,Type srcType,Type dstType,LLVMTypeConverter & converter,ConversionPatternRewriter & rewriter)176 static Value processCountOrOffset(Location loc, Value value, Type srcType,
177 Type dstType, LLVMTypeConverter &converter,
178 ConversionPatternRewriter &rewriter) {
179 Value broadcasted =
180 optionallyBroadcast(loc, value, srcType, converter, rewriter);
181 return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
182 }
183
184 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
185 /// offset to LLVM struct. Otherwise, the conversion is not supported.
186 static Optional<Type>
convertStructTypeWithOffset(spirv::StructType type,LLVMTypeConverter & converter)187 convertStructTypeWithOffset(spirv::StructType type,
188 LLVMTypeConverter &converter) {
189 if (type != VulkanLayoutUtils::decorateType(type))
190 return llvm::None;
191
192 auto elementsVector = llvm::to_vector<8>(
193 llvm::map_range(type.getElementTypes(), [&](Type elementType) {
194 return converter.convertType(elementType).cast<LLVM::LLVMType>();
195 }));
196 return LLVM::LLVMType::getStructTy(type.getContext(), elementsVector,
197 /*isPacked=*/false);
198 }
199
200 /// Converts SPIR-V struct with no offset to packed LLVM struct.
convertStructTypePacked(spirv::StructType type,LLVMTypeConverter & converter)201 static Type convertStructTypePacked(spirv::StructType type,
202 LLVMTypeConverter &converter) {
203 auto elementsVector = llvm::to_vector<8>(
204 llvm::map_range(type.getElementTypes(), [&](Type elementType) {
205 return converter.convertType(elementType).cast<LLVM::LLVMType>();
206 }));
207 return LLVM::LLVMType::getStructTy(type.getContext(), elementsVector,
208 /*isPacked=*/true);
209 }
210
211 /// Creates LLVM dialect constant with the given value.
createI32ConstantOf(Location loc,PatternRewriter & rewriter,unsigned value)212 static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
213 unsigned value) {
214 return rewriter.create<LLVM::ConstantOp>(
215 loc, LLVM::LLVMType::getInt32Ty(rewriter.getContext()),
216 rewriter.getIntegerAttr(rewriter.getI32Type(), value));
217 }
218
219 /// Utility for `spv.Load` and `spv.Store` conversion.
replaceWithLoadOrStore(Operation * op,ConversionPatternRewriter & rewriter,LLVMTypeConverter & typeConverter,unsigned alignment,bool isVolatile,bool isNonTemporal)220 static LogicalResult replaceWithLoadOrStore(Operation *op,
221 ConversionPatternRewriter &rewriter,
222 LLVMTypeConverter &typeConverter,
223 unsigned alignment, bool isVolatile,
224 bool isNonTemporal) {
225 if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
226 auto dstType = typeConverter.convertType(loadOp.getType());
227 if (!dstType)
228 return failure();
229 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
230 loadOp, dstType, loadOp.ptr(), alignment, isVolatile, isNonTemporal);
231 return success();
232 }
233 auto storeOp = cast<spirv::StoreOp>(op);
234 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, storeOp.value(),
235 storeOp.ptr(), alignment,
236 isVolatile, isNonTemporal);
237 return success();
238 }
239
240 //===----------------------------------------------------------------------===//
241 // Type conversion
242 //===----------------------------------------------------------------------===//
243
244 /// Converts SPIR-V array type to LLVM array. Natural stride (according to
245 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
246 /// when converting ops that manipulate array types.
convertArrayType(spirv::ArrayType type,TypeConverter & converter)247 static Optional<Type> convertArrayType(spirv::ArrayType type,
248 TypeConverter &converter) {
249 unsigned stride = type.getArrayStride();
250 Type elementType = type.getElementType();
251 auto sizeInBytes = elementType.cast<spirv::SPIRVType>().getSizeInBytes();
252 if (stride != 0 &&
253 !(sizeInBytes.hasValue() && sizeInBytes.getValue() == stride))
254 return llvm::None;
255
256 auto llvmElementType =
257 converter.convertType(elementType).cast<LLVM::LLVMType>();
258 unsigned numElements = type.getNumElements();
259 return LLVM::LLVMType::getArrayTy(llvmElementType, numElements);
260 }
261
262 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
263 /// modelled at the moment.
convertPointerType(spirv::PointerType type,TypeConverter & converter)264 static Type convertPointerType(spirv::PointerType type,
265 TypeConverter &converter) {
266 auto pointeeType =
267 converter.convertType(type.getPointeeType()).cast<LLVM::LLVMType>();
268 return pointeeType.getPointerTo();
269 }
270
271 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
272 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
273 /// no modelling of array stride at the moment.
convertRuntimeArrayType(spirv::RuntimeArrayType type,TypeConverter & converter)274 static Optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
275 TypeConverter &converter) {
276 if (type.getArrayStride() != 0)
277 return llvm::None;
278 auto elementType =
279 converter.convertType(type.getElementType()).cast<LLVM::LLVMType>();
280 return LLVM::LLVMType::getArrayTy(elementType, 0);
281 }
282
283 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with
284 /// member decorations. Also, only natural offset is supported.
convertStructType(spirv::StructType type,LLVMTypeConverter & converter)285 static Optional<Type> convertStructType(spirv::StructType type,
286 LLVMTypeConverter &converter) {
287 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
288 type.getMemberDecorations(memberDecorations);
289 if (!memberDecorations.empty())
290 return llvm::None;
291 if (type.hasOffset())
292 return convertStructTypeWithOffset(type, converter);
293 return convertStructTypePacked(type, converter);
294 }
295
296 //===----------------------------------------------------------------------===//
297 // Operation conversion
298 //===----------------------------------------------------------------------===//
299
300 namespace {
301
302 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
303 public:
304 using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion;
305
306 LogicalResult
matchAndRewrite(spirv::AccessChainOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const307 matchAndRewrite(spirv::AccessChainOp op, ArrayRef<Value> operands,
308 ConversionPatternRewriter &rewriter) const override {
309 auto dstType = typeConverter.convertType(op.component_ptr().getType());
310 if (!dstType)
311 return failure();
312 // To use GEP we need to add a first 0 index to go through the pointer.
313 auto indices = llvm::to_vector<4>(op.indices());
314 Type indexType = op.indices().front().getType();
315 auto llvmIndexType = typeConverter.convertType(indexType);
316 if (!llvmIndexType)
317 return failure();
318 Value zero = rewriter.create<LLVM::ConstantOp>(
319 op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
320 indices.insert(indices.begin(), zero);
321 rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, op.base_ptr(),
322 indices);
323 return success();
324 }
325 };
326
327 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
328 public:
329 using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion;
330
331 LogicalResult
matchAndRewrite(spirv::AddressOfOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const332 matchAndRewrite(spirv::AddressOfOp op, ArrayRef<Value> operands,
333 ConversionPatternRewriter &rewriter) const override {
334 auto dstType = typeConverter.convertType(op.pointer().getType());
335 if (!dstType)
336 return failure();
337 rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(
338 op, dstType.cast<LLVM::LLVMType>(), op.variable());
339 return success();
340 }
341 };
342
343 class BitFieldInsertPattern
344 : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
345 public:
346 using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion;
347
348 LogicalResult
matchAndRewrite(spirv::BitFieldInsertOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const349 matchAndRewrite(spirv::BitFieldInsertOp op, ArrayRef<Value> operands,
350 ConversionPatternRewriter &rewriter) const override {
351 auto srcType = op.getType();
352 auto dstType = typeConverter.convertType(srcType);
353 if (!dstType)
354 return failure();
355 Location loc = op.getLoc();
356
357 // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
358 Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
359 typeConverter, rewriter);
360 Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
361 typeConverter, rewriter);
362
363 // Create a mask with bits set outside [Offset, Offset + Count - 1].
364 Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
365 Value maskShiftedByCount =
366 rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
367 Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
368 maskShiftedByCount, minusOne);
369 Value maskShiftedByCountAndOffset =
370 rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
371 Value mask = rewriter.create<LLVM::XOrOp>(
372 loc, dstType, maskShiftedByCountAndOffset, minusOne);
373
374 // Extract unchanged bits from the `Base` that are outside of
375 // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
376 Value baseAndMask =
377 rewriter.create<LLVM::AndOp>(loc, dstType, op.base(), mask);
378 Value insertShiftedByOffset =
379 rewriter.create<LLVM::ShlOp>(loc, dstType, op.insert(), offset);
380 rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
381 insertShiftedByOffset);
382 return success();
383 }
384 };
385
386 /// Converts SPIR-V ConstantOp with scalar or vector type.
387 class ConstantScalarAndVectorPattern
388 : public SPIRVToLLVMConversion<spirv::ConstantOp> {
389 public:
390 using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion;
391
392 LogicalResult
matchAndRewrite(spirv::ConstantOp constOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const393 matchAndRewrite(spirv::ConstantOp constOp, ArrayRef<Value> operands,
394 ConversionPatternRewriter &rewriter) const override {
395 auto srcType = constOp.getType();
396 if (!srcType.isa<VectorType>() && !srcType.isIntOrFloat())
397 return failure();
398
399 auto dstType = typeConverter.convertType(srcType);
400 if (!dstType)
401 return failure();
402
403 // SPIR-V constant can be a signed/unsigned integer, which has to be
404 // casted to signless integer when converting to LLVM dialect. Removing the
405 // sign bit may have unexpected behaviour. However, it is better to handle
406 // it case-by-case, given that the purpose of the conversion is not to
407 // cover all possible corner cases.
408 if (isSignedIntegerOrVector(srcType) ||
409 isUnsignedIntegerOrVector(srcType)) {
410 auto *context = rewriter.getContext();
411 auto signlessType = IntegerType::get(getBitWidth(srcType), context);
412
413 if (srcType.isa<VectorType>()) {
414 auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>();
415 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
416 constOp, dstType,
417 dstElementsAttr.mapValues(
418 signlessType, [&](const APInt &value) { return value; }));
419 return success();
420 }
421 auto srcAttr = constOp.value().cast<IntegerAttr>();
422 auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
423 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
424 return success();
425 }
426 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, operands,
427 constOp.getAttrs());
428 return success();
429 }
430 };
431
432 class BitFieldSExtractPattern
433 : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
434 public:
435 using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion;
436
437 LogicalResult
matchAndRewrite(spirv::BitFieldSExtractOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const438 matchAndRewrite(spirv::BitFieldSExtractOp op, ArrayRef<Value> operands,
439 ConversionPatternRewriter &rewriter) const override {
440 auto srcType = op.getType();
441 auto dstType = typeConverter.convertType(srcType);
442 if (!dstType)
443 return failure();
444 Location loc = op.getLoc();
445
446 // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
447 Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
448 typeConverter, rewriter);
449 Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
450 typeConverter, rewriter);
451
452 // Create a constant that holds the size of the `Base`.
453 IntegerType integerType;
454 if (auto vecType = srcType.dyn_cast<VectorType>())
455 integerType = vecType.getElementType().cast<IntegerType>();
456 else
457 integerType = srcType.cast<IntegerType>();
458
459 auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
460 Value size =
461 srcType.isa<VectorType>()
462 ? rewriter.create<LLVM::ConstantOp>(
463 loc, dstType,
464 SplatElementsAttr::get(srcType.cast<ShapedType>(), baseSize))
465 : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
466
467 // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
468 // at Offset + Count - 1 is the most significant bit now.
469 Value countPlusOffset =
470 rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
471 Value amountToShiftLeft =
472 rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
473 Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
474 loc, dstType, op.base(), amountToShiftLeft);
475
476 // Shift the result right, filling the bits with the sign bit.
477 Value amountToShiftRight =
478 rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
479 rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
480 amountToShiftRight);
481 return success();
482 }
483 };
484
485 class BitFieldUExtractPattern
486 : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
487 public:
488 using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion;
489
490 LogicalResult
matchAndRewrite(spirv::BitFieldUExtractOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const491 matchAndRewrite(spirv::BitFieldUExtractOp op, ArrayRef<Value> operands,
492 ConversionPatternRewriter &rewriter) const override {
493 auto srcType = op.getType();
494 auto dstType = typeConverter.convertType(srcType);
495 if (!dstType)
496 return failure();
497 Location loc = op.getLoc();
498
499 // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
500 Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
501 typeConverter, rewriter);
502 Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
503 typeConverter, rewriter);
504
505 // Create a mask with bits set at [0, Count - 1].
506 Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
507 Value maskShiftedByCount =
508 rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
509 Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
510 minusOne);
511
512 // Shift `Base` by `Offset` and apply the mask on it.
513 Value shiftedBase =
514 rewriter.create<LLVM::LShrOp>(loc, dstType, op.base(), offset);
515 rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
516 return success();
517 }
518 };
519
520 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
521 public:
522 using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion;
523
524 LogicalResult
matchAndRewrite(spirv::BranchOp branchOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const525 matchAndRewrite(spirv::BranchOp branchOp, ArrayRef<Value> operands,
526 ConversionPatternRewriter &rewriter) const override {
527 rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, operands,
528 branchOp.getTarget());
529 return success();
530 }
531 };
532
533 class BranchConditionalConversionPattern
534 : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
535 public:
536 using SPIRVToLLVMConversion<
537 spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
538
539 LogicalResult
matchAndRewrite(spirv::BranchConditionalOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const540 matchAndRewrite(spirv::BranchConditionalOp op, ArrayRef<Value> operands,
541 ConversionPatternRewriter &rewriter) const override {
542 // If branch weights exist, map them to 32-bit integer vector.
543 ElementsAttr branchWeights = nullptr;
544 if (auto weights = op.branch_weights()) {
545 VectorType weightType = VectorType::get(2, rewriter.getI32Type());
546 branchWeights =
547 DenseElementsAttr::get(weightType, weights.getValue().getValue());
548 }
549
550 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
551 op, op.condition(), op.getTrueBlockArguments(),
552 op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
553 op.getFalseBlock());
554 return success();
555 }
556 };
557
558 /// Converts `spv.CompositeExtract` to `llvm.extractvalue` if the container type
559 /// is an aggregate type (struct or array). Otherwise, converts to
560 /// `llvm.extractelement` that operates on vectors.
561 class CompositeExtractPattern
562 : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
563 public:
564 using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion;
565
566 LogicalResult
matchAndRewrite(spirv::CompositeExtractOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const567 matchAndRewrite(spirv::CompositeExtractOp op, ArrayRef<Value> operands,
568 ConversionPatternRewriter &rewriter) const override {
569 auto dstType = this->typeConverter.convertType(op.getType());
570 if (!dstType)
571 return failure();
572
573 Type containerType = op.composite().getType();
574 if (containerType.isa<VectorType>()) {
575 Location loc = op.getLoc();
576 IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
577 Value index = createI32ConstantOf(loc, rewriter, value.getInt());
578 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
579 op, dstType, op.composite(), index);
580 return success();
581 }
582 rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
583 op, dstType, op.composite(), op.indices());
584 return success();
585 }
586 };
587
588 /// Converts `spv.CompositeInsert` to `llvm.insertvalue` if the container type
589 /// is an aggregate type (struct or array). Otherwise, converts to
590 /// `llvm.insertelement` that operates on vectors.
591 class CompositeInsertPattern
592 : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
593 public:
594 using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion;
595
596 LogicalResult
matchAndRewrite(spirv::CompositeInsertOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const597 matchAndRewrite(spirv::CompositeInsertOp op, ArrayRef<Value> operands,
598 ConversionPatternRewriter &rewriter) const override {
599 auto dstType = this->typeConverter.convertType(op.getType());
600 if (!dstType)
601 return failure();
602
603 Type containerType = op.composite().getType();
604 if (containerType.isa<VectorType>()) {
605 Location loc = op.getLoc();
606 IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
607 Value index = createI32ConstantOf(loc, rewriter, value.getInt());
608 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
609 op, dstType, op.composite(), op.object(), index);
610 return success();
611 }
612 rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
613 op, dstType, op.composite(), op.object(), op.indices());
614 return success();
615 }
616 };
617
618 /// Converts SPIR-V operations that have straightforward LLVM equivalent
619 /// into LLVM dialect operations.
620 template <typename SPIRVOp, typename LLVMOp>
621 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
622 public:
623 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
624
625 LogicalResult
matchAndRewrite(SPIRVOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const626 matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
627 ConversionPatternRewriter &rewriter) const override {
628 auto dstType = this->typeConverter.convertType(operation.getType());
629 if (!dstType)
630 return failure();
631 rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType, operands,
632 operation.getAttrs());
633 return success();
634 }
635 };
636
637 /// Converts `spv.ExecutionMode` into a global struct constant that holds
638 /// execution mode information.
639 class ExecutionModePattern
640 : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
641 public:
642 using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion;
643
644 LogicalResult
matchAndRewrite(spirv::ExecutionModeOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const645 matchAndRewrite(spirv::ExecutionModeOp op, ArrayRef<Value> operands,
646 ConversionPatternRewriter &rewriter) const override {
647 // First, create the global struct's name that would be associated with
648 // this entry point's execution mode. We set it to be:
649 // __spv__{SPIR-V module name}_{function name}_execution_mode_info
650 ModuleOp module = op->getParentOfType<ModuleOp>();
651 std::string moduleName;
652 if (module.getName().hasValue())
653 moduleName = "_" + module.getName().getValue().str();
654 else
655 moduleName = "";
656 std::string executionModeInfoName = llvm::formatv(
657 "__spv_{0}_{1}_execution_mode_info", moduleName, op.fn().str());
658
659 MLIRContext *context = rewriter.getContext();
660 OpBuilder::InsertionGuard guard(rewriter);
661 rewriter.setInsertionPointToStart(module.getBody());
662
663 // Create a struct type, corresponding to the C struct below.
664 // struct {
665 // int32_t executionMode;
666 // int32_t values[]; // optional values
667 // };
668 auto llvmI32Type = LLVM::LLVMType::getInt32Ty(context);
669 SmallVector<LLVM::LLVMType, 2> fields;
670 fields.push_back(llvmI32Type);
671 ArrayAttr values = op.values();
672 if (!values.empty()) {
673 auto arrayType = LLVM::LLVMType::getArrayTy(llvmI32Type, values.size());
674 fields.push_back(arrayType);
675 }
676 auto structType = LLVM::LLVMType::getStructTy(context, fields);
677
678 // Create `llvm.mlir.global` with initializer region containing one block.
679 auto global = rewriter.create<LLVM::GlobalOp>(
680 UnknownLoc::get(context), structType, /*isConstant=*/true,
681 LLVM::Linkage::External, executionModeInfoName, Attribute());
682 Location loc = global.getLoc();
683 Region ®ion = global.getInitializerRegion();
684 Block *block = rewriter.createBlock(®ion);
685
686 // Initialize the struct and set the execution mode value.
687 rewriter.setInsertionPoint(block, block->begin());
688 Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType);
689 IntegerAttr executionModeAttr = op.execution_modeAttr();
690 Value executionMode =
691 rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, executionModeAttr);
692 structValue = rewriter.create<LLVM::InsertValueOp>(
693 loc, structType, structValue, executionMode,
694 ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 0)},
695 context));
696
697 // Insert extra operands if they exist into execution mode info struct.
698 for (unsigned i = 0, e = values.size(); i < e; ++i) {
699 auto attr = values.getValue()[i];
700 Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
701 structValue = rewriter.create<LLVM::InsertValueOp>(
702 loc, structType, structValue, entry,
703 ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 1),
704 rewriter.getIntegerAttr(rewriter.getI32Type(), i)},
705 context));
706 }
707 rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
708 rewriter.eraseOp(op);
709 return success();
710 }
711 };
712
713 /// Converts `spv.globalVariable` to `llvm.mlir.global`. Note that SPIR-V global
714 /// returns a pointer, whereas in LLVM dialect the global holds an actual value.
715 /// This difference is handled by `spv.mlir.addressof` and
716 /// `llvm.mlir.addressof`ops that both return a pointer.
717 class GlobalVariablePattern
718 : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
719 public:
720 using SPIRVToLLVMConversion<spirv::GlobalVariableOp>::SPIRVToLLVMConversion;
721
722 LogicalResult
matchAndRewrite(spirv::GlobalVariableOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const723 matchAndRewrite(spirv::GlobalVariableOp op, ArrayRef<Value> operands,
724 ConversionPatternRewriter &rewriter) const override {
725 // Currently, there is no support of initialization with a constant value in
726 // SPIR-V dialect. Specialization constants are not considered as well.
727 if (op.initializer())
728 return failure();
729
730 auto srcType = op.type().cast<spirv::PointerType>();
731 auto dstType = typeConverter.convertType(srcType.getPointeeType());
732 if (!dstType)
733 return failure();
734
735 // Limit conversion to the current invocation only or `StorageBuffer`
736 // required by SPIR-V runner.
737 // This is okay because multiple invocations are not supported yet.
738 auto storageClass = srcType.getStorageClass();
739 if (storageClass != spirv::StorageClass::Input &&
740 storageClass != spirv::StorageClass::Private &&
741 storageClass != spirv::StorageClass::Output &&
742 storageClass != spirv::StorageClass::StorageBuffer) {
743 return failure();
744 }
745
746 // LLVM dialect spec: "If the global value is a constant, storing into it is
747 // not allowed.". This corresponds to SPIR-V 'Input' storage class that is
748 // read-only.
749 bool isConstant = storageClass == spirv::StorageClass::Input;
750 // SPIR-V spec: "By default, functions and global variables are private to a
751 // module and cannot be accessed by other modules. However, a module may be
752 // written to export or import functions and global (module scope)
753 // variables.". Therefore, map 'Private' storage class to private linkage,
754 // 'Input' and 'Output' to external linkage.
755 auto linkage = storageClass == spirv::StorageClass::Private
756 ? LLVM::Linkage::Private
757 : LLVM::Linkage::External;
758 rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
759 op, dstType.cast<LLVM::LLVMType>(), isConstant, linkage, op.sym_name(),
760 Attribute());
761 return success();
762 }
763 };
764
765 /// Converts SPIR-V cast ops that do not have straightforward LLVM
766 /// equivalent in LLVM dialect.
767 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
768 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
769 public:
770 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
771
772 LogicalResult
matchAndRewrite(SPIRVOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const773 matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
774 ConversionPatternRewriter &rewriter) const override {
775
776 Type fromType = operation.operand().getType();
777 Type toType = operation.getType();
778
779 auto dstType = this->typeConverter.convertType(toType);
780 if (!dstType)
781 return failure();
782
783 if (getBitWidth(fromType) < getBitWidth(toType)) {
784 rewriter.template replaceOpWithNewOp<LLVMExtOp>(operation, dstType,
785 operands);
786 return success();
787 }
788 if (getBitWidth(fromType) > getBitWidth(toType)) {
789 rewriter.template replaceOpWithNewOp<LLVMTruncOp>(operation, dstType,
790 operands);
791 return success();
792 }
793 return failure();
794 }
795 };
796
797 class FunctionCallPattern
798 : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
799 public:
800 using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion;
801
802 LogicalResult
matchAndRewrite(spirv::FunctionCallOp callOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const803 matchAndRewrite(spirv::FunctionCallOp callOp, ArrayRef<Value> operands,
804 ConversionPatternRewriter &rewriter) const override {
805 if (callOp.getNumResults() == 0) {
806 rewriter.replaceOpWithNewOp<LLVM::CallOp>(callOp, llvm::None, operands,
807 callOp.getAttrs());
808 return success();
809 }
810
811 // Function returns a single result.
812 auto dstType = typeConverter.convertType(callOp.getType(0));
813 rewriter.replaceOpWithNewOp<LLVM::CallOp>(callOp, dstType, operands,
814 callOp.getAttrs());
815 return success();
816 }
817 };
818
819 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
820 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
821 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
822 public:
823 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
824
825 LogicalResult
matchAndRewrite(SPIRVOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const826 matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
827 ConversionPatternRewriter &rewriter) const override {
828
829 auto dstType = this->typeConverter.convertType(operation.getType());
830 if (!dstType)
831 return failure();
832
833 rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
834 operation, dstType,
835 rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
836 operation.operand1(), operation.operand2());
837 return success();
838 }
839 };
840
841 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
842 template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
843 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
844 public:
845 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
846
847 LogicalResult
matchAndRewrite(SPIRVOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const848 matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
849 ConversionPatternRewriter &rewriter) const override {
850
851 auto dstType = this->typeConverter.convertType(operation.getType());
852 if (!dstType)
853 return failure();
854
855 rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
856 operation, dstType,
857 rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
858 operation.operand1(), operation.operand2());
859 return success();
860 }
861 };
862
863 class InverseSqrtPattern
864 : public SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp> {
865 public:
866 using SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp>::SPIRVToLLVMConversion;
867
868 LogicalResult
matchAndRewrite(spirv::GLSLInverseSqrtOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const869 matchAndRewrite(spirv::GLSLInverseSqrtOp op, ArrayRef<Value> operands,
870 ConversionPatternRewriter &rewriter) const override {
871 auto srcType = op.getType();
872 auto dstType = typeConverter.convertType(srcType);
873 if (!dstType)
874 return failure();
875
876 Location loc = op.getLoc();
877 Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
878 Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.operand());
879 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
880 return success();
881 }
882 };
883
884 /// Converts `spv.Load` and `spv.Store` to LLVM dialect.
885 template <typename SPIRVop>
886 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVop> {
887 public:
888 using SPIRVToLLVMConversion<SPIRVop>::SPIRVToLLVMConversion;
889
890 LogicalResult
matchAndRewrite(SPIRVop op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const891 matchAndRewrite(SPIRVop op, ArrayRef<Value> operands,
892 ConversionPatternRewriter &rewriter) const override {
893
894 if (!op.memory_access().hasValue()) {
895 replaceWithLoadOrStore(op, rewriter, this->typeConverter, /*alignment=*/0,
896 /*isVolatile=*/false, /*isNonTemporal=*/ false);
897 return success();
898 }
899 auto memoryAccess = op.memory_access().getValue();
900 switch (memoryAccess) {
901 case spirv::MemoryAccess::Aligned:
902 case spirv::MemoryAccess::None:
903 case spirv::MemoryAccess::Nontemporal:
904 case spirv::MemoryAccess::Volatile: {
905 unsigned alignment =
906 memoryAccess == spirv::MemoryAccess::Aligned ? *op.alignment() : 0;
907 bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
908 bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
909 replaceWithLoadOrStore(op, rewriter, this->typeConverter, alignment,
910 isVolatile, isNonTemporal);
911 return success();
912 }
913 default:
914 // There is no support of other memory access attributes.
915 return failure();
916 }
917 }
918 };
919
920 /// Converts `spv.Not` and `spv.LogicalNot` into LLVM dialect.
921 template <typename SPIRVOp>
922 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
923 public:
924 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
925
926 LogicalResult
matchAndRewrite(SPIRVOp notOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const927 matchAndRewrite(SPIRVOp notOp, ArrayRef<Value> operands,
928 ConversionPatternRewriter &rewriter) const override {
929
930 auto srcType = notOp.getType();
931 auto dstType = this->typeConverter.convertType(srcType);
932 if (!dstType)
933 return failure();
934
935 Location loc = notOp.getLoc();
936 IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
937 auto mask = srcType.template isa<VectorType>()
938 ? rewriter.create<LLVM::ConstantOp>(
939 loc, dstType,
940 SplatElementsAttr::get(
941 srcType.template cast<VectorType>(), minusOne))
942 : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
943 rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
944 notOp.operand(), mask);
945 return success();
946 }
947 };
948
949 /// A template pattern that erases the given `SPIRVOp`.
950 template <typename SPIRVOp>
951 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
952 public:
953 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
954
955 LogicalResult
matchAndRewrite(SPIRVOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const956 matchAndRewrite(SPIRVOp op, ArrayRef<Value> operands,
957 ConversionPatternRewriter &rewriter) const override {
958 rewriter.eraseOp(op);
959 return success();
960 }
961 };
962
963 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
964 public:
965 using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
966
967 LogicalResult
matchAndRewrite(spirv::ReturnOp returnOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const968 matchAndRewrite(spirv::ReturnOp returnOp, ArrayRef<Value> operands,
969 ConversionPatternRewriter &rewriter) const override {
970 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
971 ArrayRef<Value>());
972 return success();
973 }
974 };
975
976 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
977 public:
978 using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion;
979
980 LogicalResult
matchAndRewrite(spirv::ReturnValueOp returnValueOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const981 matchAndRewrite(spirv::ReturnValueOp returnValueOp, ArrayRef<Value> operands,
982 ConversionPatternRewriter &rewriter) const override {
983 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
984 operands);
985 return success();
986 }
987 };
988
989 /// Converts `spv.loop` to LLVM dialect. All blocks within selection should be
990 /// reachable for conversion to succeed.
991 /// The structure of the loop in LLVM dialect will be the following:
992 ///
993 /// +------------------------------------+
994 /// | <code before spv.loop> |
995 /// | llvm.br ^header |
996 /// +------------------------------------+
997 /// |
998 /// +----------------+ |
999 /// | | |
1000 /// | V V
1001 /// | +------------------------------------+
1002 /// | | ^header: |
1003 /// | | <header code> |
1004 /// | | llvm.cond_br %cond, ^body, ^exit |
1005 /// | +------------------------------------+
1006 /// | |
1007 /// | |----------------------+
1008 /// | | |
1009 /// | V |
1010 /// | +------------------------------------+ |
1011 /// | | ^body: | |
1012 /// | | <body code> | |
1013 /// | | llvm.br ^continue | |
1014 /// | +------------------------------------+ |
1015 /// | | |
1016 /// | V |
1017 /// | +------------------------------------+ |
1018 /// | | ^continue: | |
1019 /// | | <continue code> | |
1020 /// | | llvm.br ^header | |
1021 /// | +------------------------------------+ |
1022 /// | | |
1023 /// +---------------+ +----------------------+
1024 /// |
1025 /// V
1026 /// +------------------------------------+
1027 /// | ^exit: |
1028 /// | llvm.br ^remaining |
1029 /// +------------------------------------+
1030 /// |
1031 /// V
1032 /// +------------------------------------+
1033 /// | ^remaining: |
1034 /// | <code after spv.loop> |
1035 /// +------------------------------------+
1036 ///
1037 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1038 public:
1039 using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion;
1040
1041 LogicalResult
matchAndRewrite(spirv::LoopOp loopOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1042 matchAndRewrite(spirv::LoopOp loopOp, ArrayRef<Value> operands,
1043 ConversionPatternRewriter &rewriter) const override {
1044 // There is no support of loop control at the moment.
1045 if (loopOp.loop_control() != spirv::LoopControl::None)
1046 return failure();
1047
1048 Location loc = loopOp.getLoc();
1049
1050 // Split the current block after `spv.loop`. The remaining ops will be used
1051 // in `endBlock`.
1052 Block *currentBlock = rewriter.getBlock();
1053 auto position = Block::iterator(loopOp);
1054 Block *endBlock = rewriter.splitBlock(currentBlock, position);
1055
1056 // Remove entry block and create a branch in the current block going to the
1057 // header block.
1058 Block *entryBlock = loopOp.getEntryBlock();
1059 assert(entryBlock->getOperations().size() == 1);
1060 auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
1061 if (!brOp)
1062 return failure();
1063 Block *headerBlock = loopOp.getHeaderBlock();
1064 rewriter.setInsertionPointToEnd(currentBlock);
1065 rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1066 rewriter.eraseBlock(entryBlock);
1067
1068 // Branch from merge block to end block.
1069 Block *mergeBlock = loopOp.getMergeBlock();
1070 Operation *terminator = mergeBlock->getTerminator();
1071 ValueRange terminatorOperands = terminator->getOperands();
1072 rewriter.setInsertionPointToEnd(mergeBlock);
1073 rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1074
1075 rewriter.inlineRegionBefore(loopOp.body(), endBlock);
1076 rewriter.replaceOp(loopOp, endBlock->getArguments());
1077 return success();
1078 }
1079 };
1080
1081 /// Converts `spv.selection` with `spv.BranchConditional` in its header block.
1082 /// All blocks within selection should be reachable for conversion to succeed.
1083 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1084 public:
1085 using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion;
1086
1087 LogicalResult
matchAndRewrite(spirv::SelectionOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1088 matchAndRewrite(spirv::SelectionOp op, ArrayRef<Value> operands,
1089 ConversionPatternRewriter &rewriter) const override {
1090 // There is no support for `Flatten` or `DontFlatten` selection control at
1091 // the moment. This are just compiler hints and can be performed during the
1092 // optimization passes.
1093 if (op.selection_control() != spirv::SelectionControl::None)
1094 return failure();
1095
1096 // `spv.selection` should have at least two blocks: one selection header
1097 // block and one merge block. If no blocks are present, or control flow
1098 // branches straight to merge block (two blocks are present), the op is
1099 // redundant and it is erased.
1100 if (op.body().getBlocks().size() <= 2) {
1101 rewriter.eraseOp(op);
1102 return success();
1103 }
1104
1105 Location loc = op.getLoc();
1106
1107 // Split the current block after `spv.selection`. The remaining ops will be
1108 // used in `continueBlock`.
1109 auto *currentBlock = rewriter.getInsertionBlock();
1110 rewriter.setInsertionPointAfter(op);
1111 auto position = rewriter.getInsertionPoint();
1112 auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1113
1114 // Extract conditional branch information from the header block. By SPIR-V
1115 // dialect spec, it should contain `spv.BranchConditional` or `spv.Switch`
1116 // op. Note that `spv.Switch op` is not supported at the moment in the
1117 // SPIR-V dialect. Remove this block when finished.
1118 auto *headerBlock = op.getHeaderBlock();
1119 assert(headerBlock->getOperations().size() == 1);
1120 auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1121 headerBlock->getOperations().front());
1122 if (!condBrOp)
1123 return failure();
1124 rewriter.eraseBlock(headerBlock);
1125
1126 // Branch from merge block to continue block.
1127 auto *mergeBlock = op.getMergeBlock();
1128 Operation *terminator = mergeBlock->getTerminator();
1129 ValueRange terminatorOperands = terminator->getOperands();
1130 rewriter.setInsertionPointToEnd(mergeBlock);
1131 rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1132
1133 // Link current block to `true` and `false` blocks within the selection.
1134 Block *trueBlock = condBrOp.getTrueBlock();
1135 Block *falseBlock = condBrOp.getFalseBlock();
1136 rewriter.setInsertionPointToEnd(currentBlock);
1137 rewriter.create<LLVM::CondBrOp>(loc, condBrOp.condition(), trueBlock,
1138 condBrOp.trueTargetOperands(), falseBlock,
1139 condBrOp.falseTargetOperands());
1140
1141 rewriter.inlineRegionBefore(op.body(), continueBlock);
1142 rewriter.replaceOp(op, continueBlock->getArguments());
1143 return success();
1144 }
1145 };
1146
1147 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1148 /// puts a restriction on `Shift` and `Base` to have the same bit width,
1149 /// `Shift` is zero or sign extended to match this specification. Cases when
1150 /// `Shift` bit width > `Base` bit width are considered to be illegal.
1151 template <typename SPIRVOp, typename LLVMOp>
1152 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1153 public:
1154 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
1155
1156 LogicalResult
matchAndRewrite(SPIRVOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1157 matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
1158 ConversionPatternRewriter &rewriter) const override {
1159
1160 auto dstType = this->typeConverter.convertType(operation.getType());
1161 if (!dstType)
1162 return failure();
1163
1164 Type op1Type = operation.operand1().getType();
1165 Type op2Type = operation.operand2().getType();
1166
1167 if (op1Type == op2Type) {
1168 rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
1169 operands);
1170 return success();
1171 }
1172
1173 Location loc = operation.getLoc();
1174 Value extended;
1175 if (isUnsignedIntegerOrVector(op2Type)) {
1176 extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType,
1177 operation.operand2());
1178 } else {
1179 extended = rewriter.template create<LLVM::SExtOp>(loc, dstType,
1180 operation.operand2());
1181 }
1182 Value result = rewriter.template create<LLVMOp>(
1183 loc, dstType, operation.operand1(), extended);
1184 rewriter.replaceOp(operation, result);
1185 return success();
1186 }
1187 };
1188
1189 class TanPattern : public SPIRVToLLVMConversion<spirv::GLSLTanOp> {
1190 public:
1191 using SPIRVToLLVMConversion<spirv::GLSLTanOp>::SPIRVToLLVMConversion;
1192
1193 LogicalResult
matchAndRewrite(spirv::GLSLTanOp tanOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1194 matchAndRewrite(spirv::GLSLTanOp tanOp, ArrayRef<Value> operands,
1195 ConversionPatternRewriter &rewriter) const override {
1196 auto dstType = typeConverter.convertType(tanOp.getType());
1197 if (!dstType)
1198 return failure();
1199
1200 Location loc = tanOp.getLoc();
1201 Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.operand());
1202 Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.operand());
1203 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1204 return success();
1205 }
1206 };
1207
1208 /// Convert `spv.Tanh` to
1209 ///
1210 /// exp(2x) - 1
1211 /// -----------
1212 /// exp(2x) + 1
1213 ///
1214 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLSLTanhOp> {
1215 public:
1216 using SPIRVToLLVMConversion<spirv::GLSLTanhOp>::SPIRVToLLVMConversion;
1217
1218 LogicalResult
matchAndRewrite(spirv::GLSLTanhOp tanhOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1219 matchAndRewrite(spirv::GLSLTanhOp tanhOp, ArrayRef<Value> operands,
1220 ConversionPatternRewriter &rewriter) const override {
1221 auto srcType = tanhOp.getType();
1222 auto dstType = typeConverter.convertType(srcType);
1223 if (!dstType)
1224 return failure();
1225
1226 Location loc = tanhOp.getLoc();
1227 Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1228 Value multiplied =
1229 rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.operand());
1230 Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
1231 Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1232 Value numerator =
1233 rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
1234 Value denominator =
1235 rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
1236 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1237 denominator);
1238 return success();
1239 }
1240 };
1241
1242 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1243 public:
1244 using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion;
1245
1246 LogicalResult
matchAndRewrite(spirv::VariableOp varOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1247 matchAndRewrite(spirv::VariableOp varOp, ArrayRef<Value> operands,
1248 ConversionPatternRewriter &rewriter) const override {
1249 auto srcType = varOp.getType();
1250 // Initialization is supported for scalars and vectors only.
1251 auto pointerTo = srcType.cast<spirv::PointerType>().getPointeeType();
1252 auto init = varOp.initializer();
1253 if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa<VectorType>())
1254 return failure();
1255
1256 auto dstType = typeConverter.convertType(srcType);
1257 if (!dstType)
1258 return failure();
1259
1260 Location loc = varOp.getLoc();
1261 Value size = createI32ConstantOf(loc, rewriter, 1);
1262 if (!init) {
1263 rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, size);
1264 return success();
1265 }
1266 Value allocated = rewriter.create<LLVM::AllocaOp>(loc, dstType, size);
1267 rewriter.create<LLVM::StoreOp>(loc, init, allocated);
1268 rewriter.replaceOp(varOp, allocated);
1269 return success();
1270 }
1271 };
1272
1273 //===----------------------------------------------------------------------===//
1274 // FuncOp conversion
1275 //===----------------------------------------------------------------------===//
1276
1277 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1278 public:
1279 using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion;
1280
1281 LogicalResult
matchAndRewrite(spirv::FuncOp funcOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1282 matchAndRewrite(spirv::FuncOp funcOp, ArrayRef<Value> operands,
1283 ConversionPatternRewriter &rewriter) const override {
1284
1285 // Convert function signature. At the moment LLVMType converter is enough
1286 // for currently supported types.
1287 auto funcType = funcOp.getType();
1288 TypeConverter::SignatureConversion signatureConverter(
1289 funcType.getNumInputs());
1290 auto llvmType = typeConverter.convertFunctionSignature(
1291 funcOp.getType(), /*isVariadic=*/false, signatureConverter);
1292 if (!llvmType)
1293 return failure();
1294
1295 // Create a new `LLVMFuncOp`
1296 Location loc = funcOp.getLoc();
1297 StringRef name = funcOp.getName();
1298 auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1299
1300 // Convert SPIR-V Function Control to equivalent LLVM function attribute
1301 MLIRContext *context = funcOp.getContext();
1302 switch (funcOp.function_control()) {
1303 #define DISPATCH(functionControl, llvmAttr) \
1304 case functionControl: \
1305 newFuncOp.setAttr("passthrough", ArrayAttr::get({llvmAttr}, context)); \
1306 break;
1307
1308 DISPATCH(spirv::FunctionControl::Inline,
1309 StringAttr::get("alwaysinline", context));
1310 DISPATCH(spirv::FunctionControl::DontInline,
1311 StringAttr::get("noinline", context));
1312 DISPATCH(spirv::FunctionControl::Pure,
1313 StringAttr::get("readonly", context));
1314 DISPATCH(spirv::FunctionControl::Const,
1315 StringAttr::get("readnone", context));
1316
1317 #undef DISPATCH
1318
1319 // Default: if `spirv::FunctionControl::None`, then no attributes are
1320 // needed.
1321 default:
1322 break;
1323 }
1324
1325 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1326 newFuncOp.end());
1327 if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
1328 &signatureConverter))) {
1329 return failure();
1330 }
1331 rewriter.eraseOp(funcOp);
1332 return success();
1333 }
1334 };
1335
1336 //===----------------------------------------------------------------------===//
1337 // ModuleOp conversion
1338 //===----------------------------------------------------------------------===//
1339
1340 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1341 public:
1342 using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion;
1343
1344 LogicalResult
matchAndRewrite(spirv::ModuleOp spvModuleOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1345 matchAndRewrite(spirv::ModuleOp spvModuleOp, ArrayRef<Value> operands,
1346 ConversionPatternRewriter &rewriter) const override {
1347
1348 auto newModuleOp =
1349 rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1350 rewriter.inlineRegionBefore(spvModuleOp.body(), newModuleOp.getBody());
1351
1352 // Remove the terminator block that was automatically added by builder
1353 rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1354 rewriter.eraseOp(spvModuleOp);
1355 return success();
1356 }
1357 };
1358
1359 class ModuleEndConversionPattern
1360 : public SPIRVToLLVMConversion<spirv::ModuleEndOp> {
1361 public:
1362 using SPIRVToLLVMConversion<spirv::ModuleEndOp>::SPIRVToLLVMConversion;
1363
1364 LogicalResult
matchAndRewrite(spirv::ModuleEndOp moduleEndOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1365 matchAndRewrite(spirv::ModuleEndOp moduleEndOp, ArrayRef<Value> operands,
1366 ConversionPatternRewriter &rewriter) const override {
1367
1368 rewriter.replaceOpWithNewOp<ModuleTerminatorOp>(moduleEndOp);
1369 return success();
1370 }
1371 };
1372
1373 } // namespace
1374
1375 //===----------------------------------------------------------------------===//
1376 // Pattern population
1377 //===----------------------------------------------------------------------===//
1378
populateSPIRVToLLVMTypeConversion(LLVMTypeConverter & typeConverter)1379 void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter) {
1380 typeConverter.addConversion([&](spirv::ArrayType type) {
1381 return convertArrayType(type, typeConverter);
1382 });
1383 typeConverter.addConversion([&](spirv::PointerType type) {
1384 return convertPointerType(type, typeConverter);
1385 });
1386 typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
1387 return convertRuntimeArrayType(type, typeConverter);
1388 });
1389 typeConverter.addConversion([&](spirv::StructType type) {
1390 return convertStructType(type, typeConverter);
1391 });
1392 }
1393
populateSPIRVToLLVMConversionPatterns(MLIRContext * context,LLVMTypeConverter & typeConverter,OwningRewritePatternList & patterns)1394 void mlir::populateSPIRVToLLVMConversionPatterns(
1395 MLIRContext *context, LLVMTypeConverter &typeConverter,
1396 OwningRewritePatternList &patterns) {
1397 patterns.insert<
1398 // Arithmetic ops
1399 DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1400 DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1401 DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1402 DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1403 DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1404 DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1405 DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1406 DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1407 DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1408 DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1409 DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1410 DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1411 DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1412
1413 // Bitwise ops
1414 BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1415 DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1416 DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1417 DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1418 DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1419 DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1420 NotPattern<spirv::NotOp>,
1421
1422 // Cast ops
1423 DirectConversionPattern<spirv::BitcastOp, LLVM::BitcastOp>,
1424 DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1425 DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1426 DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1427 DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1428 IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1429 IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1430 IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1431
1432 // Comparison ops
1433 IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1434 IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1435 FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1436 FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1437 FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1438 FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1439 FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1440 FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1441 FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1442 FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1443 FComparePattern<spirv::FUnordGreaterThanEqualOp,
1444 LLVM::FCmpPredicate::uge>,
1445 FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1446 FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1447 FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1448 IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1449 IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1450 IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1451 IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1452 IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1453 IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1454 IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1455 IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1456
1457 // Constant op
1458 ConstantScalarAndVectorPattern,
1459
1460 // Control Flow ops
1461 BranchConversionPattern, BranchConditionalConversionPattern,
1462 FunctionCallPattern, LoopPattern, SelectionPattern,
1463 ErasePattern<spirv::MergeOp>,
1464
1465 // Entry points and execution mode are handled separately.
1466 ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1467
1468 // GLSL extended instruction set ops
1469 DirectConversionPattern<spirv::GLSLCeilOp, LLVM::FCeilOp>,
1470 DirectConversionPattern<spirv::GLSLCosOp, LLVM::CosOp>,
1471 DirectConversionPattern<spirv::GLSLExpOp, LLVM::ExpOp>,
1472 DirectConversionPattern<spirv::GLSLFAbsOp, LLVM::FAbsOp>,
1473 DirectConversionPattern<spirv::GLSLFloorOp, LLVM::FFloorOp>,
1474 DirectConversionPattern<spirv::GLSLFMaxOp, LLVM::MaxNumOp>,
1475 DirectConversionPattern<spirv::GLSLFMinOp, LLVM::MinNumOp>,
1476 DirectConversionPattern<spirv::GLSLLogOp, LLVM::LogOp>,
1477 DirectConversionPattern<spirv::GLSLSinOp, LLVM::SinOp>,
1478 DirectConversionPattern<spirv::GLSLSMaxOp, LLVM::SMaxOp>,
1479 DirectConversionPattern<spirv::GLSLSMinOp, LLVM::SMinOp>,
1480 DirectConversionPattern<spirv::GLSLSqrtOp, LLVM::SqrtOp>,
1481 InverseSqrtPattern, TanPattern, TanhPattern,
1482
1483 // Logical ops
1484 DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1485 DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1486 IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1487 IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1488 NotPattern<spirv::LogicalNotOp>,
1489
1490 // Memory ops
1491 AccessChainPattern, AddressOfPattern, GlobalVariablePattern,
1492 LoadStorePattern<spirv::LoadOp>, LoadStorePattern<spirv::StoreOp>,
1493 VariablePattern,
1494
1495 // Miscellaneous ops
1496 CompositeExtractPattern, CompositeInsertPattern,
1497 DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1498 DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1499
1500 // Shift ops
1501 ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1502 ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1503 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1504
1505 // Return ops
1506 ReturnPattern, ReturnValuePattern>(context, typeConverter);
1507 }
1508
populateSPIRVToLLVMFunctionConversionPatterns(MLIRContext * context,LLVMTypeConverter & typeConverter,OwningRewritePatternList & patterns)1509 void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
1510 MLIRContext *context, LLVMTypeConverter &typeConverter,
1511 OwningRewritePatternList &patterns) {
1512 patterns.insert<FuncConversionPattern>(context, typeConverter);
1513 }
1514
populateSPIRVToLLVMModuleConversionPatterns(MLIRContext * context,LLVMTypeConverter & typeConverter,OwningRewritePatternList & patterns)1515 void mlir::populateSPIRVToLLVMModuleConversionPatterns(
1516 MLIRContext *context, LLVMTypeConverter &typeConverter,
1517 OwningRewritePatternList &patterns) {
1518 patterns.insert<ModuleConversionPattern, ModuleEndConversionPattern>(
1519 context, typeConverter);
1520 }
1521
1522 //===----------------------------------------------------------------------===//
1523 // Pre-conversion hooks
1524 //===----------------------------------------------------------------------===//
1525
1526 /// Hook for descriptor set and binding number encoding.
1527 static constexpr StringRef kBinding = "binding";
1528 static constexpr StringRef kDescriptorSet = "descriptor_set";
encodeBindAttribute(ModuleOp module)1529 void mlir::encodeBindAttribute(ModuleOp module) {
1530 auto spvModules = module.getOps<spirv::ModuleOp>();
1531 for (auto spvModule : spvModules) {
1532 spvModule.walk([&](spirv::GlobalVariableOp op) {
1533 IntegerAttr descriptorSet =
1534 op->getAttrOfType<IntegerAttr>(kDescriptorSet);
1535 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
1536 // For every global variable in the module, get the ones with descriptor
1537 // set and binding numbers.
1538 if (descriptorSet && binding) {
1539 // Encode these numbers into the variable's symbolic name. If the
1540 // SPIR-V module has a name, add it at the beginning.
1541 auto moduleAndName = spvModule.getName().hasValue()
1542 ? spvModule.getName().getValue().str() + "_" +
1543 op.sym_name().str()
1544 : op.sym_name().str();
1545 std::string name =
1546 llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
1547 std::to_string(descriptorSet.getInt()),
1548 std::to_string(binding.getInt()));
1549
1550 // Replace all symbol uses and set the new symbol name. Finally, remove
1551 // descriptor set and binding attributes.
1552 if (failed(SymbolTable::replaceAllSymbolUses(op, name, spvModule)))
1553 op.emitError("unable to replace all symbol uses for ") << name;
1554 SymbolTable::setSymbolName(op, name);
1555 op.removeAttr(kDescriptorSet);
1556 op.removeAttr(kBinding);
1557 }
1558 });
1559 }
1560 }
1561