1 //===- SPIRVLowering.cpp - SPIR-V lowering utilities ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities used to lower to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/SPIRV/SPIRVLowering.h"
14 #include "mlir/Dialect/SPIRV/LayoutUtils.h"
15 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/Support/Debug.h"
20
21 #include <functional>
22
23 #define DEBUG_TYPE "mlir-spirv-lowering"
24
25 using namespace mlir;
26
27 //===----------------------------------------------------------------------===//
28 // Utility functions
29 //===----------------------------------------------------------------------===//
30
31 /// Checks that `candidates` extension requirements are possible to be satisfied
32 /// with the given `targetEnv`.
33 ///
34 /// `candidates` is a vector of vector for extension requirements following
35 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
36 /// convention.
37 template <typename LabelT>
checkExtensionRequirements(LabelT label,const spirv::TargetEnv & targetEnv,const spirv::SPIRVType::ExtensionArrayRefVector & candidates)38 static LogicalResult checkExtensionRequirements(
39 LabelT label, const spirv::TargetEnv &targetEnv,
40 const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
41 for (const auto &ors : candidates) {
42 if (targetEnv.allows(ors))
43 continue;
44
45 SmallVector<StringRef, 4> extStrings;
46 for (spirv::Extension ext : ors)
47 extStrings.push_back(spirv::stringifyExtension(ext));
48
49 LLVM_DEBUG(llvm::dbgs()
50 << label << " illegal: requires at least one extension in ["
51 << llvm::join(extStrings, ", ")
52 << "] but none allowed in target environment\n");
53 return failure();
54 }
55 return success();
56 }
57
58 /// Checks that `candidates`capability requirements are possible to be satisfied
59 /// with the given `isAllowedFn`.
60 ///
61 /// `candidates` is a vector of vector for capability requirements following
62 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
63 /// convention.
64 template <typename LabelT>
checkCapabilityRequirements(LabelT label,const spirv::TargetEnv & targetEnv,const spirv::SPIRVType::CapabilityArrayRefVector & candidates)65 static LogicalResult checkCapabilityRequirements(
66 LabelT label, const spirv::TargetEnv &targetEnv,
67 const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
68 for (const auto &ors : candidates) {
69 if (targetEnv.allows(ors))
70 continue;
71
72 SmallVector<StringRef, 4> capStrings;
73 for (spirv::Capability cap : ors)
74 capStrings.push_back(spirv::stringifyCapability(cap));
75
76 LLVM_DEBUG(llvm::dbgs()
77 << label << " illegal: requires at least one capability in ["
78 << llvm::join(capStrings, ", ")
79 << "] but none allowed in target environment\n");
80 return failure();
81 }
82 return success();
83 }
84
85 //===----------------------------------------------------------------------===//
86 // Type Conversion
87 //===----------------------------------------------------------------------===//
88
getIndexType(MLIRContext * context)89 Type SPIRVTypeConverter::getIndexType(MLIRContext *context) {
90 // Convert to 32-bit integers for now. Might need a way to control this in
91 // future.
92 // TODO: It is probably better to make it 64-bit integers. To
93 // this some support is needed in SPIR-V dialect for Conversion
94 // instructions. The Vulkan spec requires the builtins like
95 // GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be
96 // SExtended to 64-bit for index computations.
97 return IntegerType::get(32, context);
98 }
99
/// Mapping from SPIR-V storage classes to memref memory spaces.
///
/// Note: memref does not have defined semantics for each memory space; it
/// depends on the context where it is used. There are no particular reasons
/// behind the number assignments; we try to follow NVVM conventions and largely
/// give common storage classes a smaller number. The hope is to use a symbolic
/// memory space representation eventually, once memref supports it.
///
/// Each MAP_FN(storageClass, memorySpace) entry pairs one storage class with
/// one memory space number. The list is expanded as `case` labels in both
/// directions (storage class -> space and space -> storage class), so memory
/// space numbers must stay unique: a duplicate would produce duplicate `case`
/// labels in the inverse mapping.
// TODO: swap Generic and StorageBuffer assignment to be more akin
// to NVVM.
#define STORAGE_SPACE_MAP_LIST(MAP_FN)                                         \
  MAP_FN(spirv::StorageClass::Generic, 1)                                      \
  MAP_FN(spirv::StorageClass::StorageBuffer, 0)                                \
  MAP_FN(spirv::StorageClass::Workgroup, 3)                                    \
  MAP_FN(spirv::StorageClass::Uniform, 4)                                      \
  MAP_FN(spirv::StorageClass::Private, 5)                                      \
  MAP_FN(spirv::StorageClass::Function, 6)                                     \
  MAP_FN(spirv::StorageClass::PushConstant, 7)                                 \
  MAP_FN(spirv::StorageClass::UniformConstant, 8)                              \
  MAP_FN(spirv::StorageClass::Input, 9)                                        \
  MAP_FN(spirv::StorageClass::Output, 10)                                      \
  MAP_FN(spirv::StorageClass::CrossWorkgroup, 11)                              \
  MAP_FN(spirv::StorageClass::AtomicCounter, 12)                               \
  MAP_FN(spirv::StorageClass::Image, 13)                                       \
  MAP_FN(spirv::StorageClass::CallableDataNV, 14)                              \
  MAP_FN(spirv::StorageClass::IncomingCallableDataNV, 15)                      \
  MAP_FN(spirv::StorageClass::RayPayloadNV, 16)                                \
  MAP_FN(spirv::StorageClass::HitAttributeNV, 17)                              \
  MAP_FN(spirv::StorageClass::IncomingRayPayloadNV, 18)                        \
  MAP_FN(spirv::StorageClass::ShaderRecordBufferNV, 19)                        \
  MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20)
130
131 unsigned
getMemorySpaceForStorageClass(spirv::StorageClass storage)132 SPIRVTypeConverter::getMemorySpaceForStorageClass(spirv::StorageClass storage) {
133 #define STORAGE_SPACE_MAP_FN(storage, space) \
134 case storage: \
135 return space;
136
137 switch (storage) { STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) }
138 #undef STORAGE_SPACE_MAP_FN
139 llvm_unreachable("unhandled storage class!");
140 }
141
142 Optional<spirv::StorageClass>
getStorageClassForMemorySpace(unsigned space)143 SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) {
144 #define STORAGE_SPACE_MAP_FN(storage, space) \
145 case space: \
146 return storage;
147
148 switch (space) {
149 STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
150 default:
151 return llvm::None;
152 }
153 #undef STORAGE_SPACE_MAP_FN
154 }
155
156 #undef STORAGE_SPACE_MAP_LIST
157
158 // TODO: This is a utility function that should probably be
159 // exposed by the SPIR-V dialect. Keeping it local till the use case arises.
getTypeNumBytes(Type t)160 static Optional<int64_t> getTypeNumBytes(Type t) {
161 if (t.isa<spirv::ScalarType>()) {
162 auto bitWidth = t.getIntOrFloatBitWidth();
163 // According to the SPIR-V spec:
164 // "There is no physical size or bit pattern defined for values with boolean
165 // type. If they are stored (in conjunction with OpVariable), they can only
166 // be used with logical addressing operations, not physical, and only with
167 // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
168 // Private, Function, Input, and Output."
169 if (bitWidth == 1) {
170 return llvm::None;
171 }
172 return bitWidth / 8;
173 }
174 if (auto vecType = t.dyn_cast<VectorType>()) {
175 auto elementSize = getTypeNumBytes(vecType.getElementType());
176 if (!elementSize)
177 return llvm::None;
178 return vecType.getNumElements() * *elementSize;
179 }
180 if (auto memRefType = t.dyn_cast<MemRefType>()) {
181 // TODO: Layout should also be controlled by the ABI attributes. For now
182 // using the layout from MemRef.
183 int64_t offset;
184 SmallVector<int64_t, 4> strides;
185 if (!memRefType.hasStaticShape() ||
186 failed(getStridesAndOffset(memRefType, strides, offset))) {
187 return llvm::None;
188 }
189 // To get the size of the memref object in memory, the total size is the
190 // max(stride * dimension-size) computed for all dimensions times the size
191 // of the element.
192 auto elementSize = getTypeNumBytes(memRefType.getElementType());
193 if (!elementSize) {
194 return llvm::None;
195 }
196 if (memRefType.getRank() == 0) {
197 return elementSize;
198 }
199 auto dims = memRefType.getShape();
200 if (llvm::is_contained(dims, ShapedType::kDynamicSize) ||
201 offset == MemRefType::getDynamicStrideOrOffset() ||
202 llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
203 return llvm::None;
204 }
205 int64_t memrefSize = -1;
206 for (auto shape : enumerate(dims)) {
207 memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
208 }
209 return (offset + memrefSize) * elementSize.getValue();
210 } else if (auto tensorType = t.dyn_cast<TensorType>()) {
211 if (!tensorType.hasStaticShape()) {
212 return llvm::None;
213 }
214 auto elementSize = getTypeNumBytes(tensorType.getElementType());
215 if (!elementSize) {
216 return llvm::None;
217 }
218 int64_t size = elementSize.getValue();
219 for (auto shape : tensorType.getShape()) {
220 size *= shape;
221 }
222 return size;
223 }
224 // TODO: Add size computation for other types.
225 return llvm::None;
226 }
227
getConvertedTypeNumBytes(Type t)228 Optional<int64_t> SPIRVTypeConverter::getConvertedTypeNumBytes(Type t) {
229 return getTypeNumBytes(t);
230 }
231
232 /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
233 static Optional<Type>
convertScalarType(const spirv::TargetEnv & targetEnv,spirv::ScalarType type,Optional<spirv::StorageClass> storageClass={})234 convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
235 Optional<spirv::StorageClass> storageClass = {}) {
236 // Get extension and capability requirements for the given type.
237 SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
238 SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
239 type.getExtensions(extensions, storageClass);
240 type.getCapabilities(capabilities, storageClass);
241
242 // If all requirements are met, then we can accept this type as-is.
243 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
244 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
245 return type;
246
247 // Otherwise we need to adjust the type, which really means adjusting the
248 // bitwidth given this is a scalar type.
249 // TODO: We are unconditionally converting the bitwidth here,
250 // this might be okay for non-interface types (i.e., types used in
251 // Private/Function storage classes), but not for interface types (i.e.,
252 // types used in StorageBuffer/Uniform/PushConstant/etc. storage classes).
253 // This is because the later actually affects the ABI contract with the
254 // runtime. So we may want to expose a control on SPIRVTypeConverter to fail
255 // conversion if we cannot change there.
256
257 if (auto floatType = type.dyn_cast<FloatType>()) {
258 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
259 return Builder(targetEnv.getContext()).getF32Type();
260 }
261
262 auto intType = type.cast<IntegerType>();
263 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
264 return IntegerType::get(/*width=*/32, intType.getSignedness(),
265 targetEnv.getContext());
266 }
267
268 /// Converts a vector `type` to a suitable type under the given `targetEnv`.
269 static Optional<Type>
convertVectorType(const spirv::TargetEnv & targetEnv,VectorType type,Optional<spirv::StorageClass> storageClass={})270 convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type,
271 Optional<spirv::StorageClass> storageClass = {}) {
272 if (!spirv::CompositeType::isValid(type)) {
273 // TODO: One-element vector types can be translated into scalar
274 // types. Vector types with more than four elements can be translated into
275 // array types.
276 LLVM_DEBUG(llvm::dbgs()
277 << type << " illegal: 1- and > 4-element unimplemented\n");
278 return llvm::None;
279 }
280
281 // Get extension and capability requirements for the given type.
282 SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
283 SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
284 type.cast<spirv::CompositeType>().getExtensions(extensions, storageClass);
285 type.cast<spirv::CompositeType>().getCapabilities(capabilities, storageClass);
286
287 // If all requirements are met, then we can accept this type as-is.
288 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
289 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
290 return type;
291
292 auto elementType = convertScalarType(
293 targetEnv, type.getElementType().cast<spirv::ScalarType>(), storageClass);
294 if (elementType)
295 return VectorType::get(type.getShape(), *elementType);
296 return llvm::None;
297 }
298
299 /// Converts a tensor `type` to a suitable type under the given `targetEnv`.
300 ///
301 /// Note that this is mainly for lowering constant tensors.In SPIR-V one can
302 /// create composite constants with OpConstantComposite to embed relative large
303 /// constant values and use OpCompositeExtract and OpCompositeInsert to
304 /// manipulate, like what we do for vectors.
convertTensorType(const spirv::TargetEnv & targetEnv,TensorType type)305 static Optional<Type> convertTensorType(const spirv::TargetEnv &targetEnv,
306 TensorType type) {
307 // TODO: Handle dynamic shapes.
308 if (!type.hasStaticShape()) {
309 LLVM_DEBUG(llvm::dbgs()
310 << type << " illegal: dynamic shape unimplemented\n");
311 return llvm::None;
312 }
313
314 auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
315 if (!scalarType) {
316 LLVM_DEBUG(llvm::dbgs()
317 << type << " illegal: cannot convert non-scalar element type\n");
318 return llvm::None;
319 }
320
321 Optional<int64_t> scalarSize = getTypeNumBytes(scalarType);
322 Optional<int64_t> tensorSize = getTypeNumBytes(type);
323 if (!scalarSize || !tensorSize) {
324 LLVM_DEBUG(llvm::dbgs()
325 << type << " illegal: cannot deduce element count\n");
326 return llvm::None;
327 }
328
329 auto arrayElemCount = *tensorSize / *scalarSize;
330 auto arrayElemType = convertScalarType(targetEnv, scalarType);
331 if (!arrayElemType)
332 return llvm::None;
333 Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
334 if (!arrayElemSize) {
335 LLVM_DEBUG(llvm::dbgs()
336 << type << " illegal: cannot deduce converted element size\n");
337 return llvm::None;
338 }
339
340 return spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize);
341 }
342
convertMemrefType(const spirv::TargetEnv & targetEnv,MemRefType type)343 static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
344 MemRefType type) {
345 Optional<spirv::StorageClass> storageClass =
346 SPIRVTypeConverter::getStorageClassForMemorySpace(type.getMemorySpace());
347 if (!storageClass) {
348 LLVM_DEBUG(llvm::dbgs()
349 << type << " illegal: cannot convert memory space\n");
350 return llvm::None;
351 }
352
353 Optional<Type> arrayElemType;
354 Type elementType = type.getElementType();
355 if (auto vecType = elementType.dyn_cast<VectorType>()) {
356 arrayElemType = convertVectorType(targetEnv, vecType, storageClass);
357 } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
358 arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
359 } else {
360 LLVM_DEBUG(
361 llvm::dbgs()
362 << type
363 << " unhandled: can only convert scalar or vector element type\n");
364 return llvm::None;
365 }
366 if (!arrayElemType)
367 return llvm::None;
368
369 Optional<int64_t> elementSize = getTypeNumBytes(elementType);
370 if (!elementSize) {
371 LLVM_DEBUG(llvm::dbgs()
372 << type << " illegal: cannot deduce element size\n");
373 return llvm::None;
374 }
375
376 if (!type.hasStaticShape()) {
377 auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *elementSize);
378 // Wrap in a struct to satisfy Vulkan interface requirements.
379 auto structType = spirv::StructType::get(arrayType, 0);
380 return spirv::PointerType::get(structType, *storageClass);
381 }
382
383 Optional<int64_t> memrefSize = getTypeNumBytes(type);
384 if (!memrefSize) {
385 LLVM_DEBUG(llvm::dbgs()
386 << type << " illegal: cannot deduce element count\n");
387 return llvm::None;
388 }
389
390 auto arrayElemCount = *memrefSize / *elementSize;
391
392 Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
393 if (!arrayElemSize) {
394 LLVM_DEBUG(llvm::dbgs()
395 << type << " illegal: cannot deduce converted element size\n");
396 return llvm::None;
397 }
398
399 auto arrayType =
400 spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize);
401
402 // Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
403 // workgroup storage class do not need the struct to be laid out explicitly.
404 auto structType = *storageClass == spirv::StorageClass::Workgroup
405 ? spirv::StructType::get(arrayType)
406 : spirv::StructType::get(arrayType, 0);
407 return spirv::PointerType::get(structType, *storageClass);
408 }
409
/// Creates a type converter that turns builtin types into SPIR-V types usable
/// under the target environment described by `targetAttr`.
SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr)
    : targetEnv(targetAttr) {
  // Add conversions. The order matters here: later ones will be tried earlier.

  // All other cases failed. Then we cannot convert this type.
  addConversion([](Type type) { return llvm::None; });

  // Allow all SPIR-V dialect specific types. This assumes all builtin types
  // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
  // were tried before.
  //
  // TODO: this assumes that the SPIR-V types are valid to use in
  // the given target environment, which should be the case if the whole
  // pipeline is driven by the same target environment. Still, we probably
  // want to validate and convert to be safe.
  addConversion([](spirv::SPIRVType type) { return type; });

  // `index` is lowered to the integer type chosen by getIndexType.
  addConversion([](IndexType indexType) {
    return SPIRVTypeConverter::getIndexType(indexType.getContext());
  });

  // Integers that are valid SPIR-V scalar types go through convertScalarType
  // (which may widen them); for any other integer this returns llvm::None.
  addConversion([this](IntegerType intType) -> Optional<Type> {
    if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
      return convertScalarType(targetEnv, scalarType);
    return llvm::None;
  });

  // Same treatment as integers, for floating-point types.
  addConversion([this](FloatType floatType) -> Optional<Type> {
    if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>())
      return convertScalarType(targetEnv, scalarType);
    return llvm::None;
  });

  // Vectors, tensors and memrefs each have a dedicated helper; none of these
  // pass a storage class except convertMemrefType, which derives it from the
  // memref's memory space.
  addConversion([this](VectorType vectorType) {
    return convertVectorType(targetEnv, vectorType);
  });

  addConversion([this](TensorType tensorType) {
    return convertTensorType(targetEnv, tensorType);
  });

  addConversion([this](MemRefType memRefType) {
    return convertMemrefType(targetEnv, memRefType);
  });
}
455
456 //===----------------------------------------------------------------------===//
457 // FuncOp Conversion Patterns
458 //===----------------------------------------------------------------------===//
459
460 namespace {
461 /// A pattern for rewriting function signature to convert arguments of functions
462 /// to be of valid SPIR-V types.
463 class FuncOpConversion final : public SPIRVOpLowering<FuncOp> {
464 public:
465 using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
466
467 LogicalResult
468 matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
469 ConversionPatternRewriter &rewriter) const override;
470 };
471 } // namespace
472
473 LogicalResult
matchAndRewrite(FuncOp funcOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const474 FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
475 ConversionPatternRewriter &rewriter) const {
476 auto fnType = funcOp.getType();
477 // TODO: support converting functions with one result.
478 if (fnType.getNumResults())
479 return failure();
480
481 TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
482 for (auto argType : enumerate(funcOp.getType().getInputs())) {
483 auto convertedType = typeConverter.convertType(argType.value());
484 if (!convertedType)
485 return failure();
486 signatureConverter.addInputs(argType.index(), convertedType);
487 }
488
489 // Create the converted spv.func op.
490 auto newFuncOp = rewriter.create<spirv::FuncOp>(
491 funcOp.getLoc(), funcOp.getName(),
492 rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
493 llvm::None));
494
495 // Copy over all attributes other than the function name and type.
496 for (const auto &namedAttr : funcOp.getAttrs()) {
497 if (namedAttr.first != impl::getTypeAttrName() &&
498 namedAttr.first != SymbolTable::getSymbolAttrName())
499 newFuncOp.setAttr(namedAttr.first, namedAttr.second);
500 }
501
502 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
503 newFuncOp.end());
504 if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
505 &signatureConverter)))
506 return failure();
507 rewriter.eraseOp(funcOp);
508 return success();
509 }
510
populateBuiltinFuncToSPIRVPatterns(MLIRContext * context,SPIRVTypeConverter & typeConverter,OwningRewritePatternList & patterns)511 void mlir::populateBuiltinFuncToSPIRVPatterns(
512 MLIRContext *context, SPIRVTypeConverter &typeConverter,
513 OwningRewritePatternList &patterns) {
514 patterns.insert<FuncOpConversion>(context, typeConverter);
515 }
516
517 //===----------------------------------------------------------------------===//
518 // Builtin Variables
519 //===----------------------------------------------------------------------===//
520
getBuiltinVariable(Block & body,spirv::BuiltIn builtin)521 static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
522 spirv::BuiltIn builtin) {
523 // Look through all global variables in the given `body` block and check if
524 // there is a spv.globalVariable that has the same `builtin` attribute.
525 for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
526 if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
527 spirv::SPIRVDialect::getAttributeName(
528 spirv::Decoration::BuiltIn))) {
529 auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
530 if (varBuiltIn && varBuiltIn.getValue() == builtin) {
531 return varOp;
532 }
533 }
534 }
535 return nullptr;
536 }
537
538 /// Gets name of global variable for a builtin.
getBuiltinVarName(spirv::BuiltIn builtin)539 static std::string getBuiltinVarName(spirv::BuiltIn builtin) {
540 return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__";
541 }
542
543 /// Gets or inserts a global variable for a builtin within `body` block.
544 static spirv::GlobalVariableOp
getOrInsertBuiltinVariable(Block & body,Location loc,spirv::BuiltIn builtin,OpBuilder & builder)545 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
546 OpBuilder &builder) {
547 if (auto varOp = getBuiltinVariable(body, builtin))
548 return varOp;
549
550 OpBuilder::InsertionGuard guard(builder);
551 builder.setInsertionPointToStart(&body);
552
553 spirv::GlobalVariableOp newVarOp;
554 switch (builtin) {
555 case spirv::BuiltIn::NumWorkgroups:
556 case spirv::BuiltIn::WorkgroupSize:
557 case spirv::BuiltIn::WorkgroupId:
558 case spirv::BuiltIn::LocalInvocationId:
559 case spirv::BuiltIn::GlobalInvocationId: {
560 auto ptrType = spirv::PointerType::get(
561 VectorType::get({3}, builder.getIntegerType(32)),
562 spirv::StorageClass::Input);
563 std::string name = getBuiltinVarName(builtin);
564 newVarOp =
565 builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
566 break;
567 }
568 case spirv::BuiltIn::SubgroupId:
569 case spirv::BuiltIn::NumSubgroups:
570 case spirv::BuiltIn::SubgroupSize: {
571 auto ptrType = spirv::PointerType::get(builder.getIntegerType(32),
572 spirv::StorageClass::Input);
573 std::string name = getBuiltinVarName(builtin);
574 newVarOp =
575 builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
576 break;
577 }
578 default:
579 emitError(loc, "unimplemented builtin variable generation for ")
580 << stringifyBuiltIn(builtin);
581 }
582 return newVarOp;
583 }
584
getBuiltinVariableValue(Operation * op,spirv::BuiltIn builtin,OpBuilder & builder)585 Value mlir::spirv::getBuiltinVariableValue(Operation *op,
586 spirv::BuiltIn builtin,
587 OpBuilder &builder) {
588 Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
589 if (!parent) {
590 op->emitError("expected operation to be within a module-like op");
591 return nullptr;
592 }
593
594 spirv::GlobalVariableOp varOp = getOrInsertBuiltinVariable(
595 *parent->getRegion(0).begin(), op->getLoc(), builtin, builder);
596 Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
597 return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
598 }
599
600 //===----------------------------------------------------------------------===//
601 // Index calculation
602 //===----------------------------------------------------------------------===//
603
getElementPtr(SPIRVTypeConverter & typeConverter,MemRefType baseType,Value basePtr,ValueRange indices,Location loc,OpBuilder & builder)604 spirv::AccessChainOp mlir::spirv::getElementPtr(
605 SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr,
606 ValueRange indices, Location loc, OpBuilder &builder) {
607 // Get base and offset of the MemRefType and verify they are static.
608
609 int64_t offset;
610 SmallVector<int64_t, 4> strides;
611 if (failed(getStridesAndOffset(baseType, strides, offset)) ||
612 llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) ||
613 offset == MemRefType::getDynamicStrideOrOffset()) {
614 return nullptr;
615 }
616
617 auto indexType = typeConverter.getIndexType(builder.getContext());
618
619 SmallVector<Value, 2> linearizedIndices;
620 // Add a '0' at the start to index into the struct.
621 auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
622 linearizedIndices.push_back(zero);
623
624 if (baseType.getRank() == 0) {
625 linearizedIndices.push_back(zero);
626 } else {
627 // TODO: Instead of this logic, use affine.apply and add patterns for
628 // lowering affine.apply to standard ops. These will get lowered to SPIR-V
629 // ops by the DialectConversion framework.
630 Value ptrLoc = builder.create<spirv::ConstantOp>(
631 loc, indexType, IntegerAttr::get(indexType, offset));
632 assert(indices.size() == strides.size() &&
633 "must provide indices for all dimensions");
634 for (auto index : llvm::enumerate(indices)) {
635 Value strideVal = builder.create<spirv::ConstantOp>(
636 loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
637 Value update =
638 builder.create<spirv::IMulOp>(loc, strideVal, index.value());
639 ptrLoc = builder.create<spirv::IAddOp>(loc, ptrLoc, update);
640 }
641 linearizedIndices.push_back(ptrLoc);
642 }
643 return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
644 }
645
646 //===----------------------------------------------------------------------===//
647 // Set ABI attributes for lowering entry functions.
648 //===----------------------------------------------------------------------===//
649
650 LogicalResult
setABIAttrs(spirv::FuncOp funcOp,spirv::EntryPointABIAttr entryPointInfo,ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo)651 mlir::spirv::setABIAttrs(spirv::FuncOp funcOp,
652 spirv::EntryPointABIAttr entryPointInfo,
653 ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
654 // Set the attributes for argument and the function.
655 StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
656 for (auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
657 funcOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
658 }
659 funcOp.setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
660 return success();
661 }
662
663 //===----------------------------------------------------------------------===//
664 // SPIR-V ConversionTarget
665 //===----------------------------------------------------------------------===//
666
667 std::unique_ptr<spirv::SPIRVConversionTarget>
get(spirv::TargetEnvAttr targetAttr)668 spirv::SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) {
669 std::unique_ptr<SPIRVConversionTarget> target(
670 // std::make_unique does not work here because the constructor is private.
671 new SPIRVConversionTarget(targetAttr));
672 SPIRVConversionTarget *targetPtr = target.get();
673 target->addDynamicallyLegalDialect<SPIRVDialect>(
674 // We need to capture the raw pointer here because it is stable:
675 // target will be destroyed once this function is returned.
676 [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
677 return target;
678 }
679
SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)680 spirv::SPIRVConversionTarget::SPIRVConversionTarget(
681 spirv::TargetEnvAttr targetAttr)
682 : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
683
isLegalOp(Operation * op)684 bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
685 // Make sure this op is available at the given version. Ops not implementing
686 // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
687 // SPIR-V versions.
688 if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op))
689 if (minVersion.getMinVersion() > this->targetEnv.getVersion()) {
690 LLVM_DEBUG(llvm::dbgs()
691 << op->getName() << " illegal: requiring min version "
692 << spirv::stringifyVersion(minVersion.getMinVersion())
693 << "\n");
694 return false;
695 }
696 if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op))
697 if (maxVersion.getMaxVersion() < this->targetEnv.getVersion()) {
698 LLVM_DEBUG(llvm::dbgs()
699 << op->getName() << " illegal: requiring max version "
700 << spirv::stringifyVersion(maxVersion.getMaxVersion())
701 << "\n");
702 return false;
703 }
704
705 // Make sure this op's required extensions are allowed to use. Ops not
706 // implementing QueryExtensionInterface do not require extensions to be
707 // available.
708 if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
709 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
710 extensions.getExtensions())))
711 return false;
712
713 // Make sure this op's required extensions are allowed to use. Ops not
714 // implementing QueryCapabilityInterface do not require capabilities to be
715 // available.
716 if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
717 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
718 capabilities.getCapabilities())))
719 return false;
720
721 SmallVector<Type, 4> valueTypes;
722 valueTypes.append(op->operand_type_begin(), op->operand_type_end());
723 valueTypes.append(op->result_type_begin(), op->result_type_end());
724
725 // Special treatment for global variables, whose type requirements are
726 // conveyed by type attributes.
727 if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
728 valueTypes.push_back(globalVar.type());
729
730 // Make sure the op's operands/results use types that are allowed by the
731 // target environment.
732 SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
733 SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
734 for (Type valueType : valueTypes) {
735 typeExtensions.clear();
736 valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
737 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
738 typeExtensions)))
739 return false;
740
741 typeCapabilities.clear();
742 valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
743 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
744 typeCapabilities)))
745 return false;
746 }
747
748 return true;
749 }
750