• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- ConvertSPIRVToLLVM.cpp - SPIR-V dialect to LLVM dialect conversion -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h"
14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
15 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/SPIRV/LayoutUtils.h"
18 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
19 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
20 #include "mlir/Dialect/StandardOps/IR/Ops.h"
21 #include "mlir/IR/BuiltinOps.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Support/LogicalResult.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27 
28 #define DEBUG_TYPE "spirv-to-llvm-pattern"
29 
30 using namespace mlir;
31 
32 //===----------------------------------------------------------------------===//
33 // Utility functions
34 //===----------------------------------------------------------------------===//
35 
36 /// Returns true if the given type is a signed integer or vector type.
isSignedIntegerOrVector(Type type)37 static bool isSignedIntegerOrVector(Type type) {
38   if (type.isSignedInteger())
39     return true;
40   if (auto vecType = type.dyn_cast<VectorType>())
41     return vecType.getElementType().isSignedInteger();
42   return false;
43 }
44 
45 /// Returns true if the given type is an unsigned integer or vector type
isUnsignedIntegerOrVector(Type type)46 static bool isUnsignedIntegerOrVector(Type type) {
47   if (type.isUnsignedInteger())
48     return true;
49   if (auto vecType = type.dyn_cast<VectorType>())
50     return vecType.getElementType().isUnsignedInteger();
51   return false;
52 }
53 
54 /// Returns the bit width of integer, float or vector of float or integer values
getBitWidth(Type type)55 static unsigned getBitWidth(Type type) {
56   assert((type.isIntOrFloat() || type.isa<VectorType>()) &&
57          "bitwidth is not supported for this type");
58   if (type.isIntOrFloat())
59     return type.getIntOrFloatBitWidth();
60   auto vecType = type.dyn_cast<VectorType>();
61   auto elementType = vecType.getElementType();
62   assert(elementType.isIntOrFloat() &&
63          "only integers and floats have a bitwidth");
64   return elementType.getIntOrFloatBitWidth();
65 }
66 
67 /// Returns the bit width of LLVMType integer or vector.
getLLVMTypeBitWidth(LLVM::LLVMType type)68 static unsigned getLLVMTypeBitWidth(LLVM::LLVMType type) {
69   return type.isVectorTy() ? type.getVectorElementType().getIntegerBitWidth()
70                            : type.getIntegerBitWidth();
71 }
72 
73 /// Creates `IntegerAttribute` with all bits set for given type
minusOneIntegerAttribute(Type type,Builder builder)74 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
75   if (auto vecType = type.dyn_cast<VectorType>()) {
76     auto integerType = vecType.getElementType().cast<IntegerType>();
77     return builder.getIntegerAttr(integerType, -1);
78   }
79   auto integerType = type.cast<IntegerType>();
80   return builder.getIntegerAttr(integerType, -1);
81 }
82 
83 /// Creates `llvm.mlir.constant` with all bits set for the given type.
createConstantAllBitsSet(Location loc,Type srcType,Type dstType,PatternRewriter & rewriter)84 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
85                                       PatternRewriter &rewriter) {
86   if (srcType.isa<VectorType>()) {
87     return rewriter.create<LLVM::ConstantOp>(
88         loc, dstType,
89         SplatElementsAttr::get(srcType.cast<ShapedType>(),
90                                minusOneIntegerAttribute(srcType, rewriter)));
91   }
92   return rewriter.create<LLVM::ConstantOp>(
93       loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
94 }
95 
96 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
createFPConstant(Location loc,Type srcType,Type dstType,PatternRewriter & rewriter,double value)97 static Value createFPConstant(Location loc, Type srcType, Type dstType,
98                               PatternRewriter &rewriter, double value) {
99   if (auto vecType = srcType.dyn_cast<VectorType>()) {
100     auto floatType = vecType.getElementType().cast<FloatType>();
101     return rewriter.create<LLVM::ConstantOp>(
102         loc, dstType,
103         SplatElementsAttr::get(vecType,
104                                rewriter.getFloatAttr(floatType, value)));
105   }
106   auto floatType = srcType.cast<FloatType>();
107   return rewriter.create<LLVM::ConstantOp>(
108       loc, dstType, rewriter.getFloatAttr(floatType, value));
109 }
110 
111 /// Utility function for bitfield ops:
112 ///   - `BitFieldInsert`
113 ///   - `BitFieldSExtract`
114 ///   - `BitFieldUExtract`
115 /// Truncates or extends the value. If the bitwidth of the value is the same as
116 /// `dstType` bitwidth, the value remains unchanged.
optionallyTruncateOrExtend(Location loc,Value value,Type dstType,PatternRewriter & rewriter)117 static Value optionallyTruncateOrExtend(Location loc, Value value, Type dstType,
118                                         PatternRewriter &rewriter) {
119   auto srcType = value.getType();
120   auto llvmType = dstType.cast<LLVM::LLVMType>();
121   unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
122   unsigned valueBitWidth =
123       srcType.isa<LLVM::LLVMType>()
124           ? getLLVMTypeBitWidth(srcType.cast<LLVM::LLVMType>())
125           : getBitWidth(srcType);
126 
127   if (valueBitWidth < targetBitWidth)
128     return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
129   // If the bit widths of `Count` and `Offset` are greater than the bit width
130   // of the target type, they are truncated. Truncation is safe since `Count`
131   // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
132   // both values can be expressed in 8 bits.
133   if (valueBitWidth > targetBitWidth)
134     return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
135   return value;
136 }
137 
138 /// Broadcasts the value to vector with `numElements` number of elements.
broadcast(Location loc,Value toBroadcast,unsigned numElements,LLVMTypeConverter & typeConverter,ConversionPatternRewriter & rewriter)139 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
140                        LLVMTypeConverter &typeConverter,
141                        ConversionPatternRewriter &rewriter) {
142   auto vectorType = VectorType::get(numElements, toBroadcast.getType());
143   auto llvmVectorType = typeConverter.convertType(vectorType);
144   auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
145   Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType);
146   for (unsigned i = 0; i < numElements; ++i) {
147     auto index = rewriter.create<LLVM::ConstantOp>(
148         loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
149     broadcasted = rewriter.create<LLVM::InsertElementOp>(
150         loc, llvmVectorType, broadcasted, toBroadcast, index);
151   }
152   return broadcasted;
153 }
154 
155 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
optionallyBroadcast(Location loc,Value value,Type srcType,LLVMTypeConverter & typeConverter,ConversionPatternRewriter & rewriter)156 static Value optionallyBroadcast(Location loc, Value value, Type srcType,
157                                  LLVMTypeConverter &typeConverter,
158                                  ConversionPatternRewriter &rewriter) {
159   if (auto vectorType = srcType.dyn_cast<VectorType>()) {
160     unsigned numElements = vectorType.getNumElements();
161     return broadcast(loc, value, numElements, typeConverter, rewriter);
162   }
163   return value;
164 }
165 
166 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
167 /// `BitFieldUExtract`.
168 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
169 /// a vector type, construct a vector that has:
170 ///  - same number of elements as `Base`
171 ///  - each element has the type that is the same as the type of `Offset` or
172 ///    `Count`
173 ///  - each element has the same value as `Offset` or `Count`
174 /// Then cast `Offset` and `Count` if their bit width is different
175 /// from `Base` bit width.
processCountOrOffset(Location loc,Value value,Type srcType,Type dstType,LLVMTypeConverter & converter,ConversionPatternRewriter & rewriter)176 static Value processCountOrOffset(Location loc, Value value, Type srcType,
177                                   Type dstType, LLVMTypeConverter &converter,
178                                   ConversionPatternRewriter &rewriter) {
179   Value broadcasted =
180       optionallyBroadcast(loc, value, srcType, converter, rewriter);
181   return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
182 }
183 
184 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
185 /// offset to LLVM struct. Otherwise, the conversion is not supported.
186 static Optional<Type>
convertStructTypeWithOffset(spirv::StructType type,LLVMTypeConverter & converter)187 convertStructTypeWithOffset(spirv::StructType type,
188                             LLVMTypeConverter &converter) {
189   if (type != VulkanLayoutUtils::decorateType(type))
190     return llvm::None;
191 
192   auto elementsVector = llvm::to_vector<8>(
193       llvm::map_range(type.getElementTypes(), [&](Type elementType) {
194         return converter.convertType(elementType).cast<LLVM::LLVMType>();
195       }));
196   return LLVM::LLVMType::getStructTy(type.getContext(), elementsVector,
197                                      /*isPacked=*/false);
198 }
199 
200 /// Converts SPIR-V struct with no offset to packed LLVM struct.
convertStructTypePacked(spirv::StructType type,LLVMTypeConverter & converter)201 static Type convertStructTypePacked(spirv::StructType type,
202                                     LLVMTypeConverter &converter) {
203   auto elementsVector = llvm::to_vector<8>(
204       llvm::map_range(type.getElementTypes(), [&](Type elementType) {
205         return converter.convertType(elementType).cast<LLVM::LLVMType>();
206       }));
207   return LLVM::LLVMType::getStructTy(type.getContext(), elementsVector,
208                                      /*isPacked=*/true);
209 }
210 
211 /// Creates LLVM dialect constant with the given value.
createI32ConstantOf(Location loc,PatternRewriter & rewriter,unsigned value)212 static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
213                                  unsigned value) {
214   return rewriter.create<LLVM::ConstantOp>(
215       loc, LLVM::LLVMType::getInt32Ty(rewriter.getContext()),
216       rewriter.getIntegerAttr(rewriter.getI32Type(), value));
217 }
218 
219 /// Utility for `spv.Load` and `spv.Store` conversion.
replaceWithLoadOrStore(Operation * op,ConversionPatternRewriter & rewriter,LLVMTypeConverter & typeConverter,unsigned alignment,bool isVolatile,bool isNonTemporal)220 static LogicalResult replaceWithLoadOrStore(Operation *op,
221                                             ConversionPatternRewriter &rewriter,
222                                             LLVMTypeConverter &typeConverter,
223                                             unsigned alignment, bool isVolatile,
224                                             bool isNonTemporal) {
225   if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
226     auto dstType = typeConverter.convertType(loadOp.getType());
227     if (!dstType)
228       return failure();
229     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
230         loadOp, dstType, loadOp.ptr(), alignment, isVolatile, isNonTemporal);
231     return success();
232   }
233   auto storeOp = cast<spirv::StoreOp>(op);
234   rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, storeOp.value(),
235                                              storeOp.ptr(), alignment,
236                                              isVolatile, isNonTemporal);
237   return success();
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // Type conversion
242 //===----------------------------------------------------------------------===//
243 
244 /// Converts SPIR-V array type to LLVM array. Natural stride (according to
245 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
246 /// when converting ops that manipulate array types.
convertArrayType(spirv::ArrayType type,TypeConverter & converter)247 static Optional<Type> convertArrayType(spirv::ArrayType type,
248                                        TypeConverter &converter) {
249   unsigned stride = type.getArrayStride();
250   Type elementType = type.getElementType();
251   auto sizeInBytes = elementType.cast<spirv::SPIRVType>().getSizeInBytes();
252   if (stride != 0 &&
253       !(sizeInBytes.hasValue() && sizeInBytes.getValue() == stride))
254     return llvm::None;
255 
256   auto llvmElementType =
257       converter.convertType(elementType).cast<LLVM::LLVMType>();
258   unsigned numElements = type.getNumElements();
259   return LLVM::LLVMType::getArrayTy(llvmElementType, numElements);
260 }
261 
262 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
263 /// modelled at the moment.
convertPointerType(spirv::PointerType type,TypeConverter & converter)264 static Type convertPointerType(spirv::PointerType type,
265                                TypeConverter &converter) {
266   auto pointeeType =
267       converter.convertType(type.getPointeeType()).cast<LLVM::LLVMType>();
268   return pointeeType.getPointerTo();
269 }
270 
271 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
272 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
273 /// no modelling of array stride at the moment.
convertRuntimeArrayType(spirv::RuntimeArrayType type,TypeConverter & converter)274 static Optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
275                                               TypeConverter &converter) {
276   if (type.getArrayStride() != 0)
277     return llvm::None;
278   auto elementType =
279       converter.convertType(type.getElementType()).cast<LLVM::LLVMType>();
280   return LLVM::LLVMType::getArrayTy(elementType, 0);
281 }
282 
283 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with
284 /// member decorations. Also, only natural offset is supported.
convertStructType(spirv::StructType type,LLVMTypeConverter & converter)285 static Optional<Type> convertStructType(spirv::StructType type,
286                                         LLVMTypeConverter &converter) {
287   SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
288   type.getMemberDecorations(memberDecorations);
289   if (!memberDecorations.empty())
290     return llvm::None;
291   if (type.hasOffset())
292     return convertStructTypeWithOffset(type, converter);
293   return convertStructTypePacked(type, converter);
294 }
295 
296 //===----------------------------------------------------------------------===//
297 // Operation conversion
298 //===----------------------------------------------------------------------===//
299 
300 namespace {
301 
302 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
303 public:
304   using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion;
305 
306   LogicalResult
matchAndRewrite(spirv::AccessChainOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const307   matchAndRewrite(spirv::AccessChainOp op, ArrayRef<Value> operands,
308                   ConversionPatternRewriter &rewriter) const override {
309     auto dstType = typeConverter.convertType(op.component_ptr().getType());
310     if (!dstType)
311       return failure();
312     // To use GEP we need to add a first 0 index to go through the pointer.
313     auto indices = llvm::to_vector<4>(op.indices());
314     Type indexType = op.indices().front().getType();
315     auto llvmIndexType = typeConverter.convertType(indexType);
316     if (!llvmIndexType)
317       return failure();
318     Value zero = rewriter.create<LLVM::ConstantOp>(
319         op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
320     indices.insert(indices.begin(), zero);
321     rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, op.base_ptr(),
322                                              indices);
323     return success();
324   }
325 };
326 
327 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
328 public:
329   using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion;
330 
331   LogicalResult
matchAndRewrite(spirv::AddressOfOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const332   matchAndRewrite(spirv::AddressOfOp op, ArrayRef<Value> operands,
333                   ConversionPatternRewriter &rewriter) const override {
334     auto dstType = typeConverter.convertType(op.pointer().getType());
335     if (!dstType)
336       return failure();
337     rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(
338         op, dstType.cast<LLVM::LLVMType>(), op.variable());
339     return success();
340   }
341 };
342 
343 class BitFieldInsertPattern
344     : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
345 public:
346   using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion;
347 
348   LogicalResult
matchAndRewrite(spirv::BitFieldInsertOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const349   matchAndRewrite(spirv::BitFieldInsertOp op, ArrayRef<Value> operands,
350                   ConversionPatternRewriter &rewriter) const override {
351     auto srcType = op.getType();
352     auto dstType = typeConverter.convertType(srcType);
353     if (!dstType)
354       return failure();
355     Location loc = op.getLoc();
356 
357     // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
358     Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
359                                         typeConverter, rewriter);
360     Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
361                                        typeConverter, rewriter);
362 
363     // Create a mask with bits set outside [Offset, Offset + Count - 1].
364     Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
365     Value maskShiftedByCount =
366         rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
367     Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
368                                                  maskShiftedByCount, minusOne);
369     Value maskShiftedByCountAndOffset =
370         rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
371     Value mask = rewriter.create<LLVM::XOrOp>(
372         loc, dstType, maskShiftedByCountAndOffset, minusOne);
373 
374     // Extract unchanged bits from the `Base`  that are outside of
375     // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
376     Value baseAndMask =
377         rewriter.create<LLVM::AndOp>(loc, dstType, op.base(), mask);
378     Value insertShiftedByOffset =
379         rewriter.create<LLVM::ShlOp>(loc, dstType, op.insert(), offset);
380     rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
381                                             insertShiftedByOffset);
382     return success();
383   }
384 };
385 
386 /// Converts SPIR-V ConstantOp with scalar or vector type.
387 class ConstantScalarAndVectorPattern
388     : public SPIRVToLLVMConversion<spirv::ConstantOp> {
389 public:
390   using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion;
391 
392   LogicalResult
matchAndRewrite(spirv::ConstantOp constOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const393   matchAndRewrite(spirv::ConstantOp constOp, ArrayRef<Value> operands,
394                   ConversionPatternRewriter &rewriter) const override {
395     auto srcType = constOp.getType();
396     if (!srcType.isa<VectorType>() && !srcType.isIntOrFloat())
397       return failure();
398 
399     auto dstType = typeConverter.convertType(srcType);
400     if (!dstType)
401       return failure();
402 
403     // SPIR-V constant can be a signed/unsigned integer, which has to be
404     // casted to signless integer when converting to LLVM dialect. Removing the
405     // sign bit may have unexpected behaviour. However, it is better to handle
406     // it case-by-case, given that the purpose of the conversion is not to
407     // cover all possible corner cases.
408     if (isSignedIntegerOrVector(srcType) ||
409         isUnsignedIntegerOrVector(srcType)) {
410       auto *context = rewriter.getContext();
411       auto signlessType = IntegerType::get(getBitWidth(srcType), context);
412 
413       if (srcType.isa<VectorType>()) {
414         auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>();
415         rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
416             constOp, dstType,
417             dstElementsAttr.mapValues(
418                 signlessType, [&](const APInt &value) { return value; }));
419         return success();
420       }
421       auto srcAttr = constOp.value().cast<IntegerAttr>();
422       auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
423       rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
424       return success();
425     }
426     rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, operands,
427                                                   constOp.getAttrs());
428     return success();
429   }
430 };
431 
432 class BitFieldSExtractPattern
433     : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
434 public:
435   using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion;
436 
437   LogicalResult
matchAndRewrite(spirv::BitFieldSExtractOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const438   matchAndRewrite(spirv::BitFieldSExtractOp op, ArrayRef<Value> operands,
439                   ConversionPatternRewriter &rewriter) const override {
440     auto srcType = op.getType();
441     auto dstType = typeConverter.convertType(srcType);
442     if (!dstType)
443       return failure();
444     Location loc = op.getLoc();
445 
446     // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
447     Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
448                                         typeConverter, rewriter);
449     Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
450                                        typeConverter, rewriter);
451 
452     // Create a constant that holds the size of the `Base`.
453     IntegerType integerType;
454     if (auto vecType = srcType.dyn_cast<VectorType>())
455       integerType = vecType.getElementType().cast<IntegerType>();
456     else
457       integerType = srcType.cast<IntegerType>();
458 
459     auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
460     Value size =
461         srcType.isa<VectorType>()
462             ? rewriter.create<LLVM::ConstantOp>(
463                   loc, dstType,
464                   SplatElementsAttr::get(srcType.cast<ShapedType>(), baseSize))
465             : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
466 
467     // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
468     // at Offset + Count - 1 is the most significant bit now.
469     Value countPlusOffset =
470         rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
471     Value amountToShiftLeft =
472         rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
473     Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
474         loc, dstType, op.base(), amountToShiftLeft);
475 
476     // Shift the result right, filling the bits with the sign bit.
477     Value amountToShiftRight =
478         rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
479     rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
480                                               amountToShiftRight);
481     return success();
482   }
483 };
484 
485 class BitFieldUExtractPattern
486     : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
487 public:
488   using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion;
489 
490   LogicalResult
matchAndRewrite(spirv::BitFieldUExtractOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const491   matchAndRewrite(spirv::BitFieldUExtractOp op, ArrayRef<Value> operands,
492                   ConversionPatternRewriter &rewriter) const override {
493     auto srcType = op.getType();
494     auto dstType = typeConverter.convertType(srcType);
495     if (!dstType)
496       return failure();
497     Location loc = op.getLoc();
498 
499     // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
500     Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
501                                         typeConverter, rewriter);
502     Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
503                                        typeConverter, rewriter);
504 
505     // Create a mask with bits set at [0, Count - 1].
506     Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
507     Value maskShiftedByCount =
508         rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
509     Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
510                                               minusOne);
511 
512     // Shift `Base` by `Offset` and apply the mask on it.
513     Value shiftedBase =
514         rewriter.create<LLVM::LShrOp>(loc, dstType, op.base(), offset);
515     rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
516     return success();
517   }
518 };
519 
520 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
521 public:
522   using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion;
523 
524   LogicalResult
matchAndRewrite(spirv::BranchOp branchOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const525   matchAndRewrite(spirv::BranchOp branchOp, ArrayRef<Value> operands,
526                   ConversionPatternRewriter &rewriter) const override {
527     rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, operands,
528                                             branchOp.getTarget());
529     return success();
530   }
531 };
532 
533 class BranchConditionalConversionPattern
534     : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
535 public:
536   using SPIRVToLLVMConversion<
537       spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
538 
539   LogicalResult
matchAndRewrite(spirv::BranchConditionalOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const540   matchAndRewrite(spirv::BranchConditionalOp op, ArrayRef<Value> operands,
541                   ConversionPatternRewriter &rewriter) const override {
542     // If branch weights exist, map them to 32-bit integer vector.
543     ElementsAttr branchWeights = nullptr;
544     if (auto weights = op.branch_weights()) {
545       VectorType weightType = VectorType::get(2, rewriter.getI32Type());
546       branchWeights =
547           DenseElementsAttr::get(weightType, weights.getValue().getValue());
548     }
549 
550     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
551         op, op.condition(), op.getTrueBlockArguments(),
552         op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
553         op.getFalseBlock());
554     return success();
555   }
556 };
557 
558 /// Converts `spv.CompositeExtract` to `llvm.extractvalue` if the container type
559 /// is an aggregate type (struct or array). Otherwise, converts to
560 /// `llvm.extractelement` that operates on vectors.
561 class CompositeExtractPattern
562     : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
563 public:
564   using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion;
565 
566   LogicalResult
matchAndRewrite(spirv::CompositeExtractOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const567   matchAndRewrite(spirv::CompositeExtractOp op, ArrayRef<Value> operands,
568                   ConversionPatternRewriter &rewriter) const override {
569     auto dstType = this->typeConverter.convertType(op.getType());
570     if (!dstType)
571       return failure();
572 
573     Type containerType = op.composite().getType();
574     if (containerType.isa<VectorType>()) {
575       Location loc = op.getLoc();
576       IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
577       Value index = createI32ConstantOf(loc, rewriter, value.getInt());
578       rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
579           op, dstType, op.composite(), index);
580       return success();
581     }
582     rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
583         op, dstType, op.composite(), op.indices());
584     return success();
585   }
586 };
587 
588 /// Converts `spv.CompositeInsert` to `llvm.insertvalue` if the container type
589 /// is an aggregate type (struct or array). Otherwise, converts to
590 /// `llvm.insertelement` that operates on vectors.
591 class CompositeInsertPattern
592     : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
593 public:
594   using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion;
595 
596   LogicalResult
matchAndRewrite(spirv::CompositeInsertOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const597   matchAndRewrite(spirv::CompositeInsertOp op, ArrayRef<Value> operands,
598                   ConversionPatternRewriter &rewriter) const override {
599     auto dstType = this->typeConverter.convertType(op.getType());
600     if (!dstType)
601       return failure();
602 
603     Type containerType = op.composite().getType();
604     if (containerType.isa<VectorType>()) {
605       Location loc = op.getLoc();
606       IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
607       Value index = createI32ConstantOf(loc, rewriter, value.getInt());
608       rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
609           op, dstType, op.composite(), op.object(), index);
610       return success();
611     }
612     rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
613         op, dstType, op.composite(), op.object(), op.indices());
614     return success();
615   }
616 };
617 
618 /// Converts SPIR-V operations that have straightforward LLVM equivalent
619 /// into LLVM dialect operations.
620 template <typename SPIRVOp, typename LLVMOp>
621 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
622 public:
623   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
624 
625   LogicalResult
matchAndRewrite(SPIRVOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const626   matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
627                   ConversionPatternRewriter &rewriter) const override {
628     auto dstType = this->typeConverter.convertType(operation.getType());
629     if (!dstType)
630       return failure();
631     rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType, operands,
632                                                  operation.getAttrs());
633     return success();
634   }
635 };
636 
637 /// Converts `spv.ExecutionMode` into a global struct constant that holds
638 /// execution mode information.
639 class ExecutionModePattern
640     : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
641 public:
642   using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion;
643 
644   LogicalResult
matchAndRewrite(spirv::ExecutionModeOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const645   matchAndRewrite(spirv::ExecutionModeOp op, ArrayRef<Value> operands,
646                   ConversionPatternRewriter &rewriter) const override {
647     // First, create the global struct's name that would be associated with
648     // this entry point's execution mode. We set it to be:
649     //   __spv__{SPIR-V module name}_{function name}_execution_mode_info
650     ModuleOp module = op->getParentOfType<ModuleOp>();
651     std::string moduleName;
652     if (module.getName().hasValue())
653       moduleName = "_" + module.getName().getValue().str();
654     else
655       moduleName = "";
656     std::string executionModeInfoName = llvm::formatv(
657         "__spv_{0}_{1}_execution_mode_info", moduleName, op.fn().str());
658 
659     MLIRContext *context = rewriter.getContext();
660     OpBuilder::InsertionGuard guard(rewriter);
661     rewriter.setInsertionPointToStart(module.getBody());
662 
663     // Create a struct type, corresponding to the C struct below.
664     // struct {
665     //   int32_t executionMode;
666     //   int32_t values[];          // optional values
667     // };
668     auto llvmI32Type = LLVM::LLVMType::getInt32Ty(context);
669     SmallVector<LLVM::LLVMType, 2> fields;
670     fields.push_back(llvmI32Type);
671     ArrayAttr values = op.values();
672     if (!values.empty()) {
673       auto arrayType = LLVM::LLVMType::getArrayTy(llvmI32Type, values.size());
674       fields.push_back(arrayType);
675     }
676     auto structType = LLVM::LLVMType::getStructTy(context, fields);
677 
678     // Create `llvm.mlir.global` with initializer region containing one block.
679     auto global = rewriter.create<LLVM::GlobalOp>(
680         UnknownLoc::get(context), structType, /*isConstant=*/true,
681         LLVM::Linkage::External, executionModeInfoName, Attribute());
682     Location loc = global.getLoc();
683     Region &region = global.getInitializerRegion();
684     Block *block = rewriter.createBlock(&region);
685 
686     // Initialize the struct and set the execution mode value.
687     rewriter.setInsertionPoint(block, block->begin());
688     Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType);
689     IntegerAttr executionModeAttr = op.execution_modeAttr();
690     Value executionMode =
691         rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, executionModeAttr);
692     structValue = rewriter.create<LLVM::InsertValueOp>(
693         loc, structType, structValue, executionMode,
694         ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 0)},
695                        context));
696 
697     // Insert extra operands if they exist into execution mode info struct.
698     for (unsigned i = 0, e = values.size(); i < e; ++i) {
699       auto attr = values.getValue()[i];
700       Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
701       structValue = rewriter.create<LLVM::InsertValueOp>(
702           loc, structType, structValue, entry,
703           ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 1),
704                           rewriter.getIntegerAttr(rewriter.getI32Type(), i)},
705                          context));
706     }
707     rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
708     rewriter.eraseOp(op);
709     return success();
710   }
711 };
712 
713 /// Converts `spv.globalVariable` to `llvm.mlir.global`. Note that SPIR-V global
714 /// returns a pointer, whereas in LLVM dialect the global holds an actual value.
715 /// This difference is handled by `spv.mlir.addressof` and
716 /// `llvm.mlir.addressof`ops that both return a pointer.
717 class GlobalVariablePattern
718     : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
719 public:
720   using SPIRVToLLVMConversion<spirv::GlobalVariableOp>::SPIRVToLLVMConversion;
721 
722   LogicalResult
matchAndRewrite(spirv::GlobalVariableOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const723   matchAndRewrite(spirv::GlobalVariableOp op, ArrayRef<Value> operands,
724                   ConversionPatternRewriter &rewriter) const override {
725     // Currently, there is no support of initialization with a constant value in
726     // SPIR-V dialect. Specialization constants are not considered as well.
727     if (op.initializer())
728       return failure();
729 
730     auto srcType = op.type().cast<spirv::PointerType>();
731     auto dstType = typeConverter.convertType(srcType.getPointeeType());
732     if (!dstType)
733       return failure();
734 
735     // Limit conversion to the current invocation only or `StorageBuffer`
736     // required by SPIR-V runner.
737     // This is okay because multiple invocations are not supported yet.
738     auto storageClass = srcType.getStorageClass();
739     if (storageClass != spirv::StorageClass::Input &&
740         storageClass != spirv::StorageClass::Private &&
741         storageClass != spirv::StorageClass::Output &&
742         storageClass != spirv::StorageClass::StorageBuffer) {
743       return failure();
744     }
745 
746     // LLVM dialect spec: "If the global value is a constant, storing into it is
747     // not allowed.". This corresponds to SPIR-V 'Input' storage class that is
748     // read-only.
749     bool isConstant = storageClass == spirv::StorageClass::Input;
750     // SPIR-V spec: "By default, functions and global variables are private to a
751     // module and cannot be accessed by other modules. However, a module may be
752     // written to export or import functions and global (module scope)
753     // variables.". Therefore, map 'Private' storage class to private linkage,
754     // 'Input' and 'Output' to external linkage.
755     auto linkage = storageClass == spirv::StorageClass::Private
756                        ? LLVM::Linkage::Private
757                        : LLVM::Linkage::External;
758     rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
759         op, dstType.cast<LLVM::LLVMType>(), isConstant, linkage, op.sym_name(),
760         Attribute());
761     return success();
762   }
763 };
764 
765 /// Converts SPIR-V cast ops that do not have straightforward LLVM
766 /// equivalent in LLVM dialect.
767 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
768 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
769 public:
770   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
771 
772   LogicalResult
matchAndRewrite(SPIRVOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const773   matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
774                   ConversionPatternRewriter &rewriter) const override {
775 
776     Type fromType = operation.operand().getType();
777     Type toType = operation.getType();
778 
779     auto dstType = this->typeConverter.convertType(toType);
780     if (!dstType)
781       return failure();
782 
783     if (getBitWidth(fromType) < getBitWidth(toType)) {
784       rewriter.template replaceOpWithNewOp<LLVMExtOp>(operation, dstType,
785                                                       operands);
786       return success();
787     }
788     if (getBitWidth(fromType) > getBitWidth(toType)) {
789       rewriter.template replaceOpWithNewOp<LLVMTruncOp>(operation, dstType,
790                                                         operands);
791       return success();
792     }
793     return failure();
794   }
795 };
796 
797 class FunctionCallPattern
798     : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
799 public:
800   using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion;
801 
802   LogicalResult
matchAndRewrite(spirv::FunctionCallOp callOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const803   matchAndRewrite(spirv::FunctionCallOp callOp, ArrayRef<Value> operands,
804                   ConversionPatternRewriter &rewriter) const override {
805     if (callOp.getNumResults() == 0) {
806       rewriter.replaceOpWithNewOp<LLVM::CallOp>(callOp, llvm::None, operands,
807                                                 callOp.getAttrs());
808       return success();
809     }
810 
811     // Function returns a single result.
812     auto dstType = typeConverter.convertType(callOp.getType(0));
813     rewriter.replaceOpWithNewOp<LLVM::CallOp>(callOp, dstType, operands,
814                                               callOp.getAttrs());
815     return success();
816   }
817 };
818 
819 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
820 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
821 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
822 public:
823   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
824 
825   LogicalResult
matchAndRewrite(SPIRVOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const826   matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
827                   ConversionPatternRewriter &rewriter) const override {
828 
829     auto dstType = this->typeConverter.convertType(operation.getType());
830     if (!dstType)
831       return failure();
832 
833     rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
834         operation, dstType,
835         rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
836         operation.operand1(), operation.operand2());
837     return success();
838   }
839 };
840 
841 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
842 template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
843 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
844 public:
845   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
846 
847   LogicalResult
matchAndRewrite(SPIRVOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const848   matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
849                   ConversionPatternRewriter &rewriter) const override {
850 
851     auto dstType = this->typeConverter.convertType(operation.getType());
852     if (!dstType)
853       return failure();
854 
855     rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
856         operation, dstType,
857         rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
858         operation.operand1(), operation.operand2());
859     return success();
860   }
861 };
862 
863 class InverseSqrtPattern
864     : public SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp> {
865 public:
866   using SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp>::SPIRVToLLVMConversion;
867 
868   LogicalResult
matchAndRewrite(spirv::GLSLInverseSqrtOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const869   matchAndRewrite(spirv::GLSLInverseSqrtOp op, ArrayRef<Value> operands,
870                   ConversionPatternRewriter &rewriter) const override {
871     auto srcType = op.getType();
872     auto dstType = typeConverter.convertType(srcType);
873     if (!dstType)
874       return failure();
875 
876     Location loc = op.getLoc();
877     Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
878     Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.operand());
879     rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
880     return success();
881   }
882 };
883 
884 /// Converts `spv.Load` and `spv.Store` to LLVM dialect.
885 template <typename SPIRVop>
886 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVop> {
887 public:
888   using SPIRVToLLVMConversion<SPIRVop>::SPIRVToLLVMConversion;
889 
890   LogicalResult
matchAndRewrite(SPIRVop op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const891   matchAndRewrite(SPIRVop op, ArrayRef<Value> operands,
892                   ConversionPatternRewriter &rewriter) const override {
893 
894     if (!op.memory_access().hasValue()) {
895       replaceWithLoadOrStore(op, rewriter, this->typeConverter, /*alignment=*/0,
896                              /*isVolatile=*/false, /*isNonTemporal=*/ false);
897       return success();
898     }
899     auto memoryAccess = op.memory_access().getValue();
900     switch (memoryAccess) {
901     case spirv::MemoryAccess::Aligned:
902     case spirv::MemoryAccess::None:
903     case spirv::MemoryAccess::Nontemporal:
904     case spirv::MemoryAccess::Volatile: {
905       unsigned alignment =
906           memoryAccess == spirv::MemoryAccess::Aligned ? *op.alignment() : 0;
907       bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
908       bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
909       replaceWithLoadOrStore(op, rewriter, this->typeConverter, alignment,
910                              isVolatile, isNonTemporal);
911       return success();
912     }
913     default:
914       // There is no support of other memory access attributes.
915       return failure();
916     }
917   }
918 };
919 
920 /// Converts `spv.Not` and `spv.LogicalNot` into LLVM dialect.
921 template <typename SPIRVOp>
922 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
923 public:
924   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
925 
926   LogicalResult
matchAndRewrite(SPIRVOp notOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const927   matchAndRewrite(SPIRVOp notOp, ArrayRef<Value> operands,
928                   ConversionPatternRewriter &rewriter) const override {
929 
930     auto srcType = notOp.getType();
931     auto dstType = this->typeConverter.convertType(srcType);
932     if (!dstType)
933       return failure();
934 
935     Location loc = notOp.getLoc();
936     IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
937     auto mask = srcType.template isa<VectorType>()
938                     ? rewriter.create<LLVM::ConstantOp>(
939                           loc, dstType,
940                           SplatElementsAttr::get(
941                               srcType.template cast<VectorType>(), minusOne))
942                     : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
943     rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
944                                                       notOp.operand(), mask);
945     return success();
946   }
947 };
948 
949 /// A template pattern that erases the given `SPIRVOp`.
950 template <typename SPIRVOp>
951 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
952 public:
953   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
954 
955   LogicalResult
matchAndRewrite(SPIRVOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const956   matchAndRewrite(SPIRVOp op, ArrayRef<Value> operands,
957                   ConversionPatternRewriter &rewriter) const override {
958     rewriter.eraseOp(op);
959     return success();
960   }
961 };
962 
963 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
964 public:
965   using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
966 
967   LogicalResult
matchAndRewrite(spirv::ReturnOp returnOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const968   matchAndRewrite(spirv::ReturnOp returnOp, ArrayRef<Value> operands,
969                   ConversionPatternRewriter &rewriter) const override {
970     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
971                                                 ArrayRef<Value>());
972     return success();
973   }
974 };
975 
976 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
977 public:
978   using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion;
979 
980   LogicalResult
matchAndRewrite(spirv::ReturnValueOp returnValueOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const981   matchAndRewrite(spirv::ReturnValueOp returnValueOp, ArrayRef<Value> operands,
982                   ConversionPatternRewriter &rewriter) const override {
983     rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
984                                                 operands);
985     return success();
986   }
987 };
988 
989 /// Converts `spv.loop` to LLVM dialect. All blocks within selection should be
990 /// reachable for conversion to succeed.
991 /// The structure of the loop in LLVM dialect will be the following:
992 ///
993 ///      +------------------------------------+
994 ///      | <code before spv.loop>             |
995 ///      | llvm.br ^header                    |
996 ///      +------------------------------------+
997 ///                           |
998 ///   +----------------+      |
999 ///   |                |      |
1000 ///   |                V      V
1001 ///   |  +------------------------------------+
1002 ///   |  | ^header:                           |
1003 ///   |  |   <header code>                    |
1004 ///   |  |   llvm.cond_br %cond, ^body, ^exit |
1005 ///   |  +------------------------------------+
1006 ///   |                    |
1007 ///   |                    |----------------------+
1008 ///   |                    |                      |
1009 ///   |                    V                      |
1010 ///   |  +------------------------------------+   |
1011 ///   |  | ^body:                             |   |
1012 ///   |  |   <body code>                      |   |
1013 ///   |  |   llvm.br ^continue                |   |
1014 ///   |  +------------------------------------+   |
1015 ///   |                    |                      |
1016 ///   |                    V                      |
1017 ///   |  +------------------------------------+   |
1018 ///   |  | ^continue:                         |   |
1019 ///   |  |   <continue code>                  |   |
1020 ///   |  |   llvm.br ^header                  |   |
1021 ///   |  +------------------------------------+   |
1022 ///   |               |                           |
1023 ///   +---------------+    +----------------------+
1024 ///                        |
1025 ///                        V
1026 ///      +------------------------------------+
1027 ///      | ^exit:                             |
1028 ///      |   llvm.br ^remaining               |
1029 ///      +------------------------------------+
1030 ///                        |
1031 ///                        V
1032 ///      +------------------------------------+
1033 ///      | ^remaining:                        |
1034 ///      |   <code after spv.loop>            |
1035 ///      +------------------------------------+
1036 ///
1037 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1038 public:
1039   using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion;
1040 
1041   LogicalResult
matchAndRewrite(spirv::LoopOp loopOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1042   matchAndRewrite(spirv::LoopOp loopOp, ArrayRef<Value> operands,
1043                   ConversionPatternRewriter &rewriter) const override {
1044     // There is no support of loop control at the moment.
1045     if (loopOp.loop_control() != spirv::LoopControl::None)
1046       return failure();
1047 
1048     Location loc = loopOp.getLoc();
1049 
1050     // Split the current block after `spv.loop`. The remaining ops will be used
1051     // in `endBlock`.
1052     Block *currentBlock = rewriter.getBlock();
1053     auto position = Block::iterator(loopOp);
1054     Block *endBlock = rewriter.splitBlock(currentBlock, position);
1055 
1056     // Remove entry block and create a branch in the current block going to the
1057     // header block.
1058     Block *entryBlock = loopOp.getEntryBlock();
1059     assert(entryBlock->getOperations().size() == 1);
1060     auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
1061     if (!brOp)
1062       return failure();
1063     Block *headerBlock = loopOp.getHeaderBlock();
1064     rewriter.setInsertionPointToEnd(currentBlock);
1065     rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1066     rewriter.eraseBlock(entryBlock);
1067 
1068     // Branch from merge block to end block.
1069     Block *mergeBlock = loopOp.getMergeBlock();
1070     Operation *terminator = mergeBlock->getTerminator();
1071     ValueRange terminatorOperands = terminator->getOperands();
1072     rewriter.setInsertionPointToEnd(mergeBlock);
1073     rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1074 
1075     rewriter.inlineRegionBefore(loopOp.body(), endBlock);
1076     rewriter.replaceOp(loopOp, endBlock->getArguments());
1077     return success();
1078   }
1079 };
1080 
1081 /// Converts `spv.selection` with `spv.BranchConditional` in its header block.
1082 /// All blocks within selection should be reachable for conversion to succeed.
1083 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1084 public:
1085   using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion;
1086 
1087   LogicalResult
matchAndRewrite(spirv::SelectionOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1088   matchAndRewrite(spirv::SelectionOp op, ArrayRef<Value> operands,
1089                   ConversionPatternRewriter &rewriter) const override {
1090     // There is no support for `Flatten` or `DontFlatten` selection control at
1091     // the moment. This are just compiler hints and can be performed during the
1092     // optimization passes.
1093     if (op.selection_control() != spirv::SelectionControl::None)
1094       return failure();
1095 
1096     // `spv.selection` should have at least two blocks: one selection header
1097     // block and one merge block. If no blocks are present, or control flow
1098     // branches straight to merge block (two blocks are present), the op is
1099     // redundant and it is erased.
1100     if (op.body().getBlocks().size() <= 2) {
1101       rewriter.eraseOp(op);
1102       return success();
1103     }
1104 
1105     Location loc = op.getLoc();
1106 
1107     // Split the current block after `spv.selection`. The remaining ops will be
1108     // used in `continueBlock`.
1109     auto *currentBlock = rewriter.getInsertionBlock();
1110     rewriter.setInsertionPointAfter(op);
1111     auto position = rewriter.getInsertionPoint();
1112     auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1113 
1114     // Extract conditional branch information from the header block. By SPIR-V
1115     // dialect spec, it should contain `spv.BranchConditional` or `spv.Switch`
1116     // op. Note that `spv.Switch op` is not supported at the moment in the
1117     // SPIR-V dialect. Remove this block when finished.
1118     auto *headerBlock = op.getHeaderBlock();
1119     assert(headerBlock->getOperations().size() == 1);
1120     auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1121         headerBlock->getOperations().front());
1122     if (!condBrOp)
1123       return failure();
1124     rewriter.eraseBlock(headerBlock);
1125 
1126     // Branch from merge block to continue block.
1127     auto *mergeBlock = op.getMergeBlock();
1128     Operation *terminator = mergeBlock->getTerminator();
1129     ValueRange terminatorOperands = terminator->getOperands();
1130     rewriter.setInsertionPointToEnd(mergeBlock);
1131     rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1132 
1133     // Link current block to `true` and `false` blocks within the selection.
1134     Block *trueBlock = condBrOp.getTrueBlock();
1135     Block *falseBlock = condBrOp.getFalseBlock();
1136     rewriter.setInsertionPointToEnd(currentBlock);
1137     rewriter.create<LLVM::CondBrOp>(loc, condBrOp.condition(), trueBlock,
1138                                     condBrOp.trueTargetOperands(), falseBlock,
1139                                     condBrOp.falseTargetOperands());
1140 
1141     rewriter.inlineRegionBefore(op.body(), continueBlock);
1142     rewriter.replaceOp(op, continueBlock->getArguments());
1143     return success();
1144   }
1145 };
1146 
1147 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1148 /// puts a restriction on `Shift` and `Base` to have the same bit width,
1149 /// `Shift` is zero or sign extended to match this specification. Cases when
1150 /// `Shift` bit width > `Base` bit width are considered to be illegal.
1151 template <typename SPIRVOp, typename LLVMOp>
1152 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1153 public:
1154   using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
1155 
1156   LogicalResult
matchAndRewrite(SPIRVOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1157   matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
1158                   ConversionPatternRewriter &rewriter) const override {
1159 
1160     auto dstType = this->typeConverter.convertType(operation.getType());
1161     if (!dstType)
1162       return failure();
1163 
1164     Type op1Type = operation.operand1().getType();
1165     Type op2Type = operation.operand2().getType();
1166 
1167     if (op1Type == op2Type) {
1168       rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
1169                                                    operands);
1170       return success();
1171     }
1172 
1173     Location loc = operation.getLoc();
1174     Value extended;
1175     if (isUnsignedIntegerOrVector(op2Type)) {
1176       extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType,
1177                                                         operation.operand2());
1178     } else {
1179       extended = rewriter.template create<LLVM::SExtOp>(loc, dstType,
1180                                                         operation.operand2());
1181     }
1182     Value result = rewriter.template create<LLVMOp>(
1183         loc, dstType, operation.operand1(), extended);
1184     rewriter.replaceOp(operation, result);
1185     return success();
1186   }
1187 };
1188 
1189 class TanPattern : public SPIRVToLLVMConversion<spirv::GLSLTanOp> {
1190 public:
1191   using SPIRVToLLVMConversion<spirv::GLSLTanOp>::SPIRVToLLVMConversion;
1192 
1193   LogicalResult
matchAndRewrite(spirv::GLSLTanOp tanOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1194   matchAndRewrite(spirv::GLSLTanOp tanOp, ArrayRef<Value> operands,
1195                   ConversionPatternRewriter &rewriter) const override {
1196     auto dstType = typeConverter.convertType(tanOp.getType());
1197     if (!dstType)
1198       return failure();
1199 
1200     Location loc = tanOp.getLoc();
1201     Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.operand());
1202     Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.operand());
1203     rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1204     return success();
1205   }
1206 };
1207 
1208 /// Convert `spv.Tanh` to
1209 ///
1210 ///   exp(2x) - 1
1211 ///   -----------
1212 ///   exp(2x) + 1
1213 ///
1214 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLSLTanhOp> {
1215 public:
1216   using SPIRVToLLVMConversion<spirv::GLSLTanhOp>::SPIRVToLLVMConversion;
1217 
1218   LogicalResult
matchAndRewrite(spirv::GLSLTanhOp tanhOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1219   matchAndRewrite(spirv::GLSLTanhOp tanhOp, ArrayRef<Value> operands,
1220                   ConversionPatternRewriter &rewriter) const override {
1221     auto srcType = tanhOp.getType();
1222     auto dstType = typeConverter.convertType(srcType);
1223     if (!dstType)
1224       return failure();
1225 
1226     Location loc = tanhOp.getLoc();
1227     Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1228     Value multiplied =
1229         rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.operand());
1230     Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
1231     Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1232     Value numerator =
1233         rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
1234     Value denominator =
1235         rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
1236     rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1237                                               denominator);
1238     return success();
1239   }
1240 };
1241 
1242 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1243 public:
1244   using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion;
1245 
1246   LogicalResult
matchAndRewrite(spirv::VariableOp varOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1247   matchAndRewrite(spirv::VariableOp varOp, ArrayRef<Value> operands,
1248                   ConversionPatternRewriter &rewriter) const override {
1249     auto srcType = varOp.getType();
1250     // Initialization is supported for scalars and vectors only.
1251     auto pointerTo = srcType.cast<spirv::PointerType>().getPointeeType();
1252     auto init = varOp.initializer();
1253     if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa<VectorType>())
1254       return failure();
1255 
1256     auto dstType = typeConverter.convertType(srcType);
1257     if (!dstType)
1258       return failure();
1259 
1260     Location loc = varOp.getLoc();
1261     Value size = createI32ConstantOf(loc, rewriter, 1);
1262     if (!init) {
1263       rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, size);
1264       return success();
1265     }
1266     Value allocated = rewriter.create<LLVM::AllocaOp>(loc, dstType, size);
1267     rewriter.create<LLVM::StoreOp>(loc, init, allocated);
1268     rewriter.replaceOp(varOp, allocated);
1269     return success();
1270   }
1271 };
1272 
1273 //===----------------------------------------------------------------------===//
1274 // FuncOp conversion
1275 //===----------------------------------------------------------------------===//
1276 
1277 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1278 public:
1279   using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion;
1280 
1281   LogicalResult
matchAndRewrite(spirv::FuncOp funcOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1282   matchAndRewrite(spirv::FuncOp funcOp, ArrayRef<Value> operands,
1283                   ConversionPatternRewriter &rewriter) const override {
1284 
1285     // Convert function signature. At the moment LLVMType converter is enough
1286     // for currently supported types.
1287     auto funcType = funcOp.getType();
1288     TypeConverter::SignatureConversion signatureConverter(
1289         funcType.getNumInputs());
1290     auto llvmType = typeConverter.convertFunctionSignature(
1291         funcOp.getType(), /*isVariadic=*/false, signatureConverter);
1292     if (!llvmType)
1293       return failure();
1294 
1295     // Create a new `LLVMFuncOp`
1296     Location loc = funcOp.getLoc();
1297     StringRef name = funcOp.getName();
1298     auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1299 
1300     // Convert SPIR-V Function Control to equivalent LLVM function attribute
1301     MLIRContext *context = funcOp.getContext();
1302     switch (funcOp.function_control()) {
1303 #define DISPATCH(functionControl, llvmAttr)                                    \
1304   case functionControl:                                                        \
1305     newFuncOp.setAttr("passthrough", ArrayAttr::get({llvmAttr}, context));     \
1306     break;
1307 
1308           DISPATCH(spirv::FunctionControl::Inline,
1309                    StringAttr::get("alwaysinline", context));
1310           DISPATCH(spirv::FunctionControl::DontInline,
1311                    StringAttr::get("noinline", context));
1312           DISPATCH(spirv::FunctionControl::Pure,
1313                    StringAttr::get("readonly", context));
1314           DISPATCH(spirv::FunctionControl::Const,
1315                    StringAttr::get("readnone", context));
1316 
1317 #undef DISPATCH
1318 
1319     // Default: if `spirv::FunctionControl::None`, then no attributes are
1320     // needed.
1321     default:
1322       break;
1323     }
1324 
1325     rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1326                                 newFuncOp.end());
1327     if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
1328                                            &signatureConverter))) {
1329       return failure();
1330     }
1331     rewriter.eraseOp(funcOp);
1332     return success();
1333   }
1334 };
1335 
1336 //===----------------------------------------------------------------------===//
1337 // ModuleOp conversion
1338 //===----------------------------------------------------------------------===//
1339 
1340 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1341 public:
1342   using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion;
1343 
1344   LogicalResult
matchAndRewrite(spirv::ModuleOp spvModuleOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1345   matchAndRewrite(spirv::ModuleOp spvModuleOp, ArrayRef<Value> operands,
1346                   ConversionPatternRewriter &rewriter) const override {
1347 
1348     auto newModuleOp =
1349         rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1350     rewriter.inlineRegionBefore(spvModuleOp.body(), newModuleOp.getBody());
1351 
1352     // Remove the terminator block that was automatically added by builder
1353     rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1354     rewriter.eraseOp(spvModuleOp);
1355     return success();
1356   }
1357 };
1358 
1359 class ModuleEndConversionPattern
1360     : public SPIRVToLLVMConversion<spirv::ModuleEndOp> {
1361 public:
1362   using SPIRVToLLVMConversion<spirv::ModuleEndOp>::SPIRVToLLVMConversion;
1363 
1364   LogicalResult
matchAndRewrite(spirv::ModuleEndOp moduleEndOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1365   matchAndRewrite(spirv::ModuleEndOp moduleEndOp, ArrayRef<Value> operands,
1366                   ConversionPatternRewriter &rewriter) const override {
1367 
1368     rewriter.replaceOpWithNewOp<ModuleTerminatorOp>(moduleEndOp);
1369     return success();
1370   }
1371 };
1372 
1373 } // namespace
1374 
1375 //===----------------------------------------------------------------------===//
1376 // Pattern population
1377 //===----------------------------------------------------------------------===//
1378 
populateSPIRVToLLVMTypeConversion(LLVMTypeConverter & typeConverter)1379 void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter) {
1380   typeConverter.addConversion([&](spirv::ArrayType type) {
1381     return convertArrayType(type, typeConverter);
1382   });
1383   typeConverter.addConversion([&](spirv::PointerType type) {
1384     return convertPointerType(type, typeConverter);
1385   });
1386   typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
1387     return convertRuntimeArrayType(type, typeConverter);
1388   });
1389   typeConverter.addConversion([&](spirv::StructType type) {
1390     return convertStructType(type, typeConverter);
1391   });
1392 }
1393 
populateSPIRVToLLVMConversionPatterns(MLIRContext * context,LLVMTypeConverter & typeConverter,OwningRewritePatternList & patterns)1394 void mlir::populateSPIRVToLLVMConversionPatterns(
1395     MLIRContext *context, LLVMTypeConverter &typeConverter,
1396     OwningRewritePatternList &patterns) {
1397   patterns.insert<
1398       // Arithmetic ops
1399       DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1400       DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1401       DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1402       DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1403       DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1404       DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1405       DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1406       DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1407       DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1408       DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1409       DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1410       DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1411       DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1412 
1413       // Bitwise ops
1414       BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1415       DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1416       DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1417       DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1418       DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1419       DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1420       NotPattern<spirv::NotOp>,
1421 
1422       // Cast ops
1423       DirectConversionPattern<spirv::BitcastOp, LLVM::BitcastOp>,
1424       DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1425       DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1426       DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1427       DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1428       IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1429       IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1430       IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1431 
1432       // Comparison ops
1433       IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1434       IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1435       FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1436       FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1437       FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1438       FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1439       FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1440       FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1441       FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1442       FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1443       FComparePattern<spirv::FUnordGreaterThanEqualOp,
1444                       LLVM::FCmpPredicate::uge>,
1445       FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1446       FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1447       FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1448       IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1449       IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1450       IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1451       IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1452       IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1453       IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1454       IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1455       IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1456 
1457       // Constant op
1458       ConstantScalarAndVectorPattern,
1459 
1460       // Control Flow ops
1461       BranchConversionPattern, BranchConditionalConversionPattern,
1462       FunctionCallPattern, LoopPattern, SelectionPattern,
1463       ErasePattern<spirv::MergeOp>,
1464 
1465       // Entry points and execution mode are handled separately.
1466       ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1467 
1468       // GLSL extended instruction set ops
1469       DirectConversionPattern<spirv::GLSLCeilOp, LLVM::FCeilOp>,
1470       DirectConversionPattern<spirv::GLSLCosOp, LLVM::CosOp>,
1471       DirectConversionPattern<spirv::GLSLExpOp, LLVM::ExpOp>,
1472       DirectConversionPattern<spirv::GLSLFAbsOp, LLVM::FAbsOp>,
1473       DirectConversionPattern<spirv::GLSLFloorOp, LLVM::FFloorOp>,
1474       DirectConversionPattern<spirv::GLSLFMaxOp, LLVM::MaxNumOp>,
1475       DirectConversionPattern<spirv::GLSLFMinOp, LLVM::MinNumOp>,
1476       DirectConversionPattern<spirv::GLSLLogOp, LLVM::LogOp>,
1477       DirectConversionPattern<spirv::GLSLSinOp, LLVM::SinOp>,
1478       DirectConversionPattern<spirv::GLSLSMaxOp, LLVM::SMaxOp>,
1479       DirectConversionPattern<spirv::GLSLSMinOp, LLVM::SMinOp>,
1480       DirectConversionPattern<spirv::GLSLSqrtOp, LLVM::SqrtOp>,
1481       InverseSqrtPattern, TanPattern, TanhPattern,
1482 
1483       // Logical ops
1484       DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1485       DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1486       IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1487       IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1488       NotPattern<spirv::LogicalNotOp>,
1489 
1490       // Memory ops
1491       AccessChainPattern, AddressOfPattern, GlobalVariablePattern,
1492       LoadStorePattern<spirv::LoadOp>, LoadStorePattern<spirv::StoreOp>,
1493       VariablePattern,
1494 
1495       // Miscellaneous ops
1496       CompositeExtractPattern, CompositeInsertPattern,
1497       DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1498       DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1499 
1500       // Shift ops
1501       ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1502       ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1503       ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1504 
1505       // Return ops
1506       ReturnPattern, ReturnValuePattern>(context, typeConverter);
1507 }
1508 
populateSPIRVToLLVMFunctionConversionPatterns(MLIRContext * context,LLVMTypeConverter & typeConverter,OwningRewritePatternList & patterns)1509 void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
1510     MLIRContext *context, LLVMTypeConverter &typeConverter,
1511     OwningRewritePatternList &patterns) {
1512   patterns.insert<FuncConversionPattern>(context, typeConverter);
1513 }
1514 
populateSPIRVToLLVMModuleConversionPatterns(MLIRContext * context,LLVMTypeConverter & typeConverter,OwningRewritePatternList & patterns)1515 void mlir::populateSPIRVToLLVMModuleConversionPatterns(
1516     MLIRContext *context, LLVMTypeConverter &typeConverter,
1517     OwningRewritePatternList &patterns) {
1518   patterns.insert<ModuleConversionPattern, ModuleEndConversionPattern>(
1519       context, typeConverter);
1520 }
1521 
1522 //===----------------------------------------------------------------------===//
1523 // Pre-conversion hooks
1524 //===----------------------------------------------------------------------===//
1525 
1526 /// Hook for descriptor set and binding number encoding.
1527 static constexpr StringRef kBinding = "binding";
1528 static constexpr StringRef kDescriptorSet = "descriptor_set";
encodeBindAttribute(ModuleOp module)1529 void mlir::encodeBindAttribute(ModuleOp module) {
1530   auto spvModules = module.getOps<spirv::ModuleOp>();
1531   for (auto spvModule : spvModules) {
1532     spvModule.walk([&](spirv::GlobalVariableOp op) {
1533       IntegerAttr descriptorSet =
1534           op->getAttrOfType<IntegerAttr>(kDescriptorSet);
1535       IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
1536       // For every global variable in the module, get the ones with descriptor
1537       // set and binding numbers.
1538       if (descriptorSet && binding) {
1539         // Encode these numbers into the variable's symbolic name. If the
1540         // SPIR-V module has a name, add it at the beginning.
1541         auto moduleAndName = spvModule.getName().hasValue()
1542                                  ? spvModule.getName().getValue().str() + "_" +
1543                                        op.sym_name().str()
1544                                  : op.sym_name().str();
1545         std::string name =
1546             llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
1547                           std::to_string(descriptorSet.getInt()),
1548                           std::to_string(binding.getInt()));
1549 
1550         // Replace all symbol uses and set the new symbol name. Finally, remove
1551         // descriptor set and binding attributes.
1552         if (failed(SymbolTable::replaceAllSymbolUses(op, name, spvModule)))
1553           op.emitError("unable to replace all symbol uses for ") << name;
1554         SymbolTable::setSymbolName(op, name);
1555         op.removeAttr(kDescriptorSet);
1556         op.removeAttr(kBinding);
1557       }
1558     });
1559   }
1560 }
1561