//===- SPIRVLowering.h - SPIR-V lowering utilities -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines utilities to use while targeting the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPIRV_SPIRVLOWERING_H
#define MLIR_DIALECT_SPIRV_SPIRVLOWERING_H

#include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/TargetAndABI.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallSet.h"

namespace mlir {

/// Type conversion from builtin types to SPIR-V types for shader interface.
///
/// Non-32-bit scalar types require special hardware support that may not exist
/// on all GPUs. SPIR-V reflects this by requiring special capabilities or
/// extensions for non-32-bit scalar types. Right now, if a scalar type of a
/// certain bitwidth is not supported in the target environment, we
/// unconditionally use the 32-bit type instead. This requires the runtime to
/// also feed in data with a matching bitwidth and layout for interface types.
/// The runtime can do that by inspecting the SPIR-V module.
///
/// For memref types, this converter additionally performs type wrapping to
/// satisfy shader interface requirements: shader interface types must be
/// pointers to structs.
///
/// TODO: We might want to introduce a way to control how unsupported bitwidths
/// are handled and explicitly fail if wanted.
40 class SPIRVTypeConverter : public TypeConverter { 41 public: 42 explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr); 43 44 /// Gets the number of bytes used for a type when converted to SPIR-V 45 /// type. Note that it doesnt account for whether the type is legal for a 46 /// SPIR-V target (described by spirv::TargetEnvAttr). Returns None on 47 /// failure. 48 static Optional<int64_t> getConvertedTypeNumBytes(Type); 49 50 /// Gets the SPIR-V correspondence for the standard index type. 51 static Type getIndexType(MLIRContext *context); 52 53 /// Returns the corresponding memory space for memref given a SPIR-V storage 54 /// class. 55 static unsigned getMemorySpaceForStorageClass(spirv::StorageClass); 56 57 /// Returns the SPIR-V storage class given a memory space for memref. Return 58 /// llvm::None if the memory space does not map to any SPIR-V storage class. 59 static Optional<spirv::StorageClass> 60 getStorageClassForMemorySpace(unsigned space); 61 62 private: 63 spirv::TargetEnv targetEnv; 64 }; 65 66 /// Base class to define a conversion pattern to lower `SourceOp` into SPIR-V. 67 template <typename SourceOp> 68 class SPIRVOpLowering : public OpConversionPattern<SourceOp> { 69 public: 70 SPIRVOpLowering(MLIRContext *context, SPIRVTypeConverter &typeConverter, 71 PatternBenefit benefit = 1) 72 : OpConversionPattern<SourceOp>(context, benefit), 73 typeConverter(typeConverter) {} 74 75 protected: 76 SPIRVTypeConverter &typeConverter; 77 }; 78 79 /// Appends to a pattern list additional patterns for translating the builtin 80 /// `func` op to the SPIR-V dialect. These patterns do not handle shader 81 /// interface/ABI; they convert function parameters to be of SPIR-V allowed 82 /// types. 
83 void populateBuiltinFuncToSPIRVPatterns(MLIRContext *context, 84 SPIRVTypeConverter &typeConverter, 85 OwningRewritePatternList &patterns); 86 87 namespace spirv { 88 class AccessChainOp; 89 class FuncOp; 90 91 class SPIRVConversionTarget : public ConversionTarget { 92 public: 93 /// Creates a SPIR-V conversion target for the given target environment. 94 static std::unique_ptr<SPIRVConversionTarget> get(TargetEnvAttr targetAttr); 95 96 private: 97 explicit SPIRVConversionTarget(TargetEnvAttr targetAttr); 98 99 // Be explicit that instance of this class cannot be copied or moved: there 100 // are lambdas capturing fields of the instance. 101 SPIRVConversionTarget(const SPIRVConversionTarget &) = delete; 102 SPIRVConversionTarget(SPIRVConversionTarget &&) = delete; 103 SPIRVConversionTarget &operator=(const SPIRVConversionTarget &) = delete; 104 SPIRVConversionTarget &operator=(SPIRVConversionTarget &&) = delete; 105 106 /// Returns true if the given `op` is legal to use under the current target 107 /// environment. 108 bool isLegalOp(Operation *op); 109 110 TargetEnv targetEnv; 111 }; 112 113 /// Returns the value for the given `builtin` variable. This function gets or 114 /// inserts the global variable associated for the builtin within the nearest 115 /// enclosing op that has a symbol table. Returns null Value if such an 116 /// enclosing op cannot be found. 117 Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, 118 OpBuilder &builder); 119 120 /// Performs the index computation to get to the element at `indices` of the 121 /// memory pointed to by `basePtr`, using the layout map of `baseType`. 122 123 // TODO: This method assumes that the `baseType` is a MemRefType with AffineMap 124 // that has static strides. Extend to handle dynamic strides. 
125 spirv::AccessChainOp getElementPtr(SPIRVTypeConverter &typeConverter, 126 MemRefType baseType, Value basePtr, 127 ValueRange indices, Location loc, 128 OpBuilder &builder); 129 130 /// Sets the InterfaceVarABIAttr and EntryPointABIAttr for a function and its 131 /// arguments. 132 LogicalResult setABIAttrs(spirv::FuncOp funcOp, 133 EntryPointABIAttr entryPointInfo, 134 ArrayRef<InterfaceVarABIAttr> argABIInfo); 135 } // namespace spirv 136 } // namespace mlir 137 138 #endif // MLIR_DIALECT_SPIRV_SPIRVLOWERING_H 139