• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- SPIRVLowering.h - SPIR-V lowering utilities  -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines utilities to use while targeting SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_SPIRV_SPIRVLOWERING_H
14 #define MLIR_DIALECT_SPIRV_SPIRVLOWERING_H
15 
16 #include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
17 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
18 #include "mlir/Dialect/SPIRV/TargetAndABI.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "llvm/ADT/SmallSet.h"
21 
22 namespace mlir {
23 
24 /// Type conversion from builtin types to SPIR-V types for shader interface.
25 ///
26 /// Non-32-bit scalar types require special hardware support that may not exist
27 /// on all GPUs. This is reflected in SPIR-V as that non-32-bit scalar types
28 /// require special capabilities or extensions. Right now if a scalar type of a
29 /// certain bitwidth is not supported in the target environment, we use 32-bit
30 /// ones unconditionally. This requires the runtime to also feed in data with
31 /// a matched bitwidth and layout for interface types. The runtime can do that
32 /// by inspecting the SPIR-V module.
33 ///
34 /// For memref types, this converter additionally performs type wrapping to
35 /// satisfy shader interface requirements: shader interface types must be
36 /// pointers to structs.
37 ///
38 /// TODO: We might want to introduce a way to control how unsupported bitwidth
39 /// are handled and explicitly fail if wanted.
40 class SPIRVTypeConverter : public TypeConverter {
41 public:
42   explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr);
43 
44   /// Gets the number of bytes used for a type when converted to SPIR-V
45   /// type. Note that it doesnt account for whether the type is legal for a
46   /// SPIR-V target (described by spirv::TargetEnvAttr). Returns None on
47   /// failure.
48   static Optional<int64_t> getConvertedTypeNumBytes(Type);
49 
50   /// Gets the SPIR-V correspondence for the standard index type.
51   static Type getIndexType(MLIRContext *context);
52 
53   /// Returns the corresponding memory space for memref given a SPIR-V storage
54   /// class.
55   static unsigned getMemorySpaceForStorageClass(spirv::StorageClass);
56 
57   /// Returns the SPIR-V storage class given a memory space for memref. Return
58   /// llvm::None if the memory space does not map to any SPIR-V storage class.
59   static Optional<spirv::StorageClass>
60   getStorageClassForMemorySpace(unsigned space);
61 
62 private:
63   spirv::TargetEnv targetEnv;
64 };
65 
66 /// Base class to define a conversion pattern to lower `SourceOp` into SPIR-V.
67 template <typename SourceOp>
68 class SPIRVOpLowering : public OpConversionPattern<SourceOp> {
69 public:
70   SPIRVOpLowering(MLIRContext *context, SPIRVTypeConverter &typeConverter,
71                   PatternBenefit benefit = 1)
72       : OpConversionPattern<SourceOp>(context, benefit),
73         typeConverter(typeConverter) {}
74 
75 protected:
76   SPIRVTypeConverter &typeConverter;
77 };
78 
79 /// Appends to a pattern list additional patterns for translating the builtin
80 /// `func` op to the SPIR-V dialect. These patterns do not handle shader
81 /// interface/ABI; they convert function parameters to be of SPIR-V allowed
82 /// types.
83 void populateBuiltinFuncToSPIRVPatterns(MLIRContext *context,
84                                         SPIRVTypeConverter &typeConverter,
85                                         OwningRewritePatternList &patterns);
86 
87 namespace spirv {
88 class AccessChainOp;
89 class FuncOp;
90 
91 class SPIRVConversionTarget : public ConversionTarget {
92 public:
93   /// Creates a SPIR-V conversion target for the given target environment.
94   static std::unique_ptr<SPIRVConversionTarget> get(TargetEnvAttr targetAttr);
95 
96 private:
97   explicit SPIRVConversionTarget(TargetEnvAttr targetAttr);
98 
99   // Be explicit that instance of this class cannot be copied or moved: there
100   // are lambdas capturing fields of the instance.
101   SPIRVConversionTarget(const SPIRVConversionTarget &) = delete;
102   SPIRVConversionTarget(SPIRVConversionTarget &&) = delete;
103   SPIRVConversionTarget &operator=(const SPIRVConversionTarget &) = delete;
104   SPIRVConversionTarget &operator=(SPIRVConversionTarget &&) = delete;
105 
106   /// Returns true if the given `op` is legal to use under the current target
107   /// environment.
108   bool isLegalOp(Operation *op);
109 
110   TargetEnv targetEnv;
111 };
112 
113 /// Returns the value for the given `builtin` variable. This function gets or
114 /// inserts the global variable associated for the builtin within the nearest
115 /// enclosing op that has a symbol table. Returns null Value if such an
116 /// enclosing op cannot be found.
117 Value getBuiltinVariableValue(Operation *op, BuiltIn builtin,
118                               OpBuilder &builder);
119 
120 /// Performs the index computation to get to the element at `indices` of the
121 /// memory pointed to by `basePtr`, using the layout map of `baseType`.
122 
123 // TODO: This method assumes that the `baseType` is a MemRefType with AffineMap
124 // that has static strides. Extend to handle dynamic strides.
125 spirv::AccessChainOp getElementPtr(SPIRVTypeConverter &typeConverter,
126                                    MemRefType baseType, Value basePtr,
127                                    ValueRange indices, Location loc,
128                                    OpBuilder &builder);
129 
130 /// Sets the InterfaceVarABIAttr and EntryPointABIAttr for a function and its
131 /// arguments.
132 LogicalResult setABIAttrs(spirv::FuncOp funcOp,
133                           EntryPointABIAttr entryPointInfo,
134                           ArrayRef<InterfaceVarABIAttr> argABIInfo);
135 } // namespace spirv
136 } // namespace mlir
137 
138 #endif // MLIR_DIALECT_SPIRV_SPIRVLOWERING_H
139