1 //===- LLVMTypes.cpp - MLIR LLVM Dialect types ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the types for the LLVM dialect in MLIR. These MLIR types
10 // correspond to the LLVM IR type system.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "TypeDetail.h"
15
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/TypeSupport.h"
20
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/TypeSize.h"
23
24 using namespace mlir;
25 using namespace mlir::LLVM;
26
27 //===----------------------------------------------------------------------===//
28 // LLVMType.
29 //===----------------------------------------------------------------------===//
30
classof(Type type)31 bool LLVMType::classof(Type type) {
32 return llvm::isa<LLVMDialect>(type.getDialect());
33 }
34
getDialect()35 LLVMDialect &LLVMType::getDialect() {
36 return static_cast<LLVMDialect &>(Type::getDialect());
37 }
38
39 //----------------------------------------------------------------------------//
40 // Misc type utilities.
41
getPrimitiveSizeInBits()42 llvm::TypeSize LLVMType::getPrimitiveSizeInBits() {
43 return llvm::TypeSwitch<LLVMType, llvm::TypeSize>(*this)
44 .Case<LLVMHalfType, LLVMBFloatType>(
45 [](LLVMType) { return llvm::TypeSize::Fixed(16); })
46 .Case<LLVMFloatType>([](LLVMType) { return llvm::TypeSize::Fixed(32); })
47 .Case<LLVMDoubleType, LLVMX86MMXType>(
48 [](LLVMType) { return llvm::TypeSize::Fixed(64); })
49 .Case<LLVMIntegerType>([](LLVMIntegerType intTy) {
50 return llvm::TypeSize::Fixed(intTy.getBitWidth());
51 })
52 .Case<LLVMX86FP80Type>([](LLVMType) { return llvm::TypeSize::Fixed(80); })
53 .Case<LLVMPPCFP128Type, LLVMFP128Type>(
54 [](LLVMType) { return llvm::TypeSize::Fixed(128); })
55 .Case<LLVMVectorType>([](LLVMVectorType t) {
56 llvm::TypeSize elementSize =
57 t.getElementType().getPrimitiveSizeInBits();
58 llvm::ElementCount elementCount = t.getElementCount();
59 assert(!elementSize.isScalable() &&
60 "vector type should have fixed-width elements");
61 return llvm::TypeSize(elementSize.getFixedSize() *
62 elementCount.getKnownMinValue(),
63 elementCount.isScalable());
64 })
65 .Default([](LLVMType ty) {
66 assert((ty.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
67 LLVMTokenType, LLVMStructType, LLVMArrayType,
68 LLVMPointerType, LLVMFunctionType>()) &&
69 "unexpected missing support for primitive type");
70 return llvm::TypeSize::Fixed(0);
71 });
72 }
73
74 //----------------------------------------------------------------------------//
75 // Integer type utilities.
76
isIntegerTy(unsigned bitwidth)77 bool LLVMType::isIntegerTy(unsigned bitwidth) {
78 if (auto intType = dyn_cast<LLVMIntegerType>())
79 return intType.getBitWidth() == bitwidth;
80 return false;
81 }
getIntegerBitWidth()82 unsigned LLVMType::getIntegerBitWidth() {
83 return cast<LLVMIntegerType>().getBitWidth();
84 }
85
getArrayElementType()86 LLVMType LLVMType::getArrayElementType() {
87 return cast<LLVMArrayType>().getElementType();
88 }
89
90 //----------------------------------------------------------------------------//
91 // Array type utilities.
92
getArrayNumElements()93 unsigned LLVMType::getArrayNumElements() {
94 return cast<LLVMArrayType>().getNumElements();
95 }
96
isArrayTy()97 bool LLVMType::isArrayTy() { return isa<LLVMArrayType>(); }
98
99 //----------------------------------------------------------------------------//
100 // Vector type utilities.
101
getVectorElementType()102 LLVMType LLVMType::getVectorElementType() {
103 return cast<LLVMVectorType>().getElementType();
104 }
105
getVectorNumElements()106 unsigned LLVMType::getVectorNumElements() {
107 return cast<LLVMFixedVectorType>().getNumElements();
108 }
getVectorElementCount()109 llvm::ElementCount LLVMType::getVectorElementCount() {
110 return cast<LLVMVectorType>().getElementCount();
111 }
112
isVectorTy()113 bool LLVMType::isVectorTy() { return isa<LLVMVectorType>(); }
114
115 //----------------------------------------------------------------------------//
116 // Function type utilities.
117
getFunctionParamType(unsigned argIdx)118 LLVMType LLVMType::getFunctionParamType(unsigned argIdx) {
119 return cast<LLVMFunctionType>().getParamType(argIdx);
120 }
121
getFunctionNumParams()122 unsigned LLVMType::getFunctionNumParams() {
123 return cast<LLVMFunctionType>().getNumParams();
124 }
125
getFunctionResultType()126 LLVMType LLVMType::getFunctionResultType() {
127 return cast<LLVMFunctionType>().getReturnType();
128 }
129
isFunctionTy()130 bool LLVMType::isFunctionTy() { return isa<LLVMFunctionType>(); }
131
isFunctionVarArg()132 bool LLVMType::isFunctionVarArg() {
133 return cast<LLVMFunctionType>().isVarArg();
134 }
135
136 //----------------------------------------------------------------------------//
137 // Pointer type utilities.
138
getPointerTo(unsigned addrSpace)139 LLVMType LLVMType::getPointerTo(unsigned addrSpace) {
140 return LLVMPointerType::get(*this, addrSpace);
141 }
142
getPointerElementTy()143 LLVMType LLVMType::getPointerElementTy() {
144 return cast<LLVMPointerType>().getElementType();
145 }
146
isPointerTy()147 bool LLVMType::isPointerTy() { return isa<LLVMPointerType>(); }
148
149 //----------------------------------------------------------------------------//
150 // Struct type utilities.
151
getStructElementType(unsigned i)152 LLVMType LLVMType::getStructElementType(unsigned i) {
153 return cast<LLVMStructType>().getBody()[i];
154 }
155
getStructNumElements()156 unsigned LLVMType::getStructNumElements() {
157 return cast<LLVMStructType>().getBody().size();
158 }
159
isStructTy()160 bool LLVMType::isStructTy() { return isa<LLVMStructType>(); }
161
162 //----------------------------------------------------------------------------//
163 // Utilities used to generate floating point types.
164
getDoubleTy(MLIRContext * context)165 LLVMType LLVMType::getDoubleTy(MLIRContext *context) {
166 return LLVMDoubleType::get(context);
167 }
168
getFloatTy(MLIRContext * context)169 LLVMType LLVMType::getFloatTy(MLIRContext *context) {
170 return LLVMFloatType::get(context);
171 }
172
getBFloatTy(MLIRContext * context)173 LLVMType LLVMType::getBFloatTy(MLIRContext *context) {
174 return LLVMBFloatType::get(context);
175 }
176
getHalfTy(MLIRContext * context)177 LLVMType LLVMType::getHalfTy(MLIRContext *context) {
178 return LLVMHalfType::get(context);
179 }
180
getFP128Ty(MLIRContext * context)181 LLVMType LLVMType::getFP128Ty(MLIRContext *context) {
182 return LLVMFP128Type::get(context);
183 }
184
getX86_FP80Ty(MLIRContext * context)185 LLVMType LLVMType::getX86_FP80Ty(MLIRContext *context) {
186 return LLVMX86FP80Type::get(context);
187 }
188
189 //----------------------------------------------------------------------------//
190 // Utilities used to generate integer types.
191
getIntNTy(MLIRContext * context,unsigned numBits)192 LLVMType LLVMType::getIntNTy(MLIRContext *context, unsigned numBits) {
193 return LLVMIntegerType::get(context, numBits);
194 }
195
196 //----------------------------------------------------------------------------//
197 // Utilities used to generate other miscellaneous types.
198
getArrayTy(LLVMType elementType,uint64_t numElements)199 LLVMType LLVMType::getArrayTy(LLVMType elementType, uint64_t numElements) {
200 return LLVMArrayType::get(elementType, numElements);
201 }
202
getFunctionTy(LLVMType result,ArrayRef<LLVMType> params,bool isVarArg)203 LLVMType LLVMType::getFunctionTy(LLVMType result, ArrayRef<LLVMType> params,
204 bool isVarArg) {
205 return LLVMFunctionType::get(result, params, isVarArg);
206 }
207
getStructTy(MLIRContext * context,ArrayRef<LLVMType> elements,bool isPacked)208 LLVMType LLVMType::getStructTy(MLIRContext *context,
209 ArrayRef<LLVMType> elements, bool isPacked) {
210 return LLVMStructType::getLiteral(context, elements, isPacked);
211 }
212
getVectorTy(LLVMType elementType,unsigned numElements)213 LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) {
214 return LLVMFixedVectorType::get(elementType, numElements);
215 }
216
217 //----------------------------------------------------------------------------//
218 // Void type utilities.
219
getVoidTy(MLIRContext * context)220 LLVMType LLVMType::getVoidTy(MLIRContext *context) {
221 return LLVMVoidType::get(context);
222 }
223
isVoidTy()224 bool LLVMType::isVoidTy() { return isa<LLVMVoidType>(); }
225
226 //----------------------------------------------------------------------------//
227 // Creation and setting of LLVM's identified struct types
228
createStructTy(MLIRContext * context,ArrayRef<LLVMType> elements,Optional<StringRef> name,bool isPacked)229 LLVMType LLVMType::createStructTy(MLIRContext *context,
230 ArrayRef<LLVMType> elements,
231 Optional<StringRef> name, bool isPacked) {
232 assert(name.hasValue() &&
233 "identified structs with no identifier not supported");
234 StringRef stringNameBase = name.getValueOr("");
235 std::string stringName = stringNameBase.str();
236 unsigned counter = 0;
237 do {
238 auto type = LLVMStructType::getIdentified(context, stringName);
239 if (type.isInitialized() || failed(type.setBody(elements, isPacked))) {
240 counter += 1;
241 stringName =
242 (Twine(stringNameBase) + "." + std::to_string(counter)).str();
243 continue;
244 }
245 return type;
246 } while (true);
247 }
248
setStructTyBody(LLVMType structType,ArrayRef<LLVMType> elements,bool isPacked)249 LLVMType LLVMType::setStructTyBody(LLVMType structType,
250 ArrayRef<LLVMType> elements, bool isPacked) {
251 LogicalResult couldSet =
252 structType.cast<LLVMStructType>().setBody(elements, isPacked);
253 assert(succeeded(couldSet) && "failed to set the body");
254 (void)couldSet;
255 return structType;
256 }
257
258 //===----------------------------------------------------------------------===//
259 // Array type.
260
isValidElementType(LLVMType type)261 bool LLVMArrayType::isValidElementType(LLVMType type) {
262 return !type.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
263 LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>();
264 }
265
get(LLVMType elementType,unsigned numElements)266 LLVMArrayType LLVMArrayType::get(LLVMType elementType, unsigned numElements) {
267 assert(elementType && "expected non-null subtype");
268 return Base::get(elementType.getContext(), elementType, numElements);
269 }
270
getChecked(Location loc,LLVMType elementType,unsigned numElements)271 LLVMArrayType LLVMArrayType::getChecked(Location loc, LLVMType elementType,
272 unsigned numElements) {
273 assert(elementType && "expected non-null subtype");
274 return Base::getChecked(loc, elementType, numElements);
275 }
276
getElementType()277 LLVMType LLVMArrayType::getElementType() { return getImpl()->elementType; }
278
getNumElements()279 unsigned LLVMArrayType::getNumElements() { return getImpl()->numElements; }
280
281 LogicalResult
verifyConstructionInvariants(Location loc,LLVMType elementType,unsigned numElements)282 LLVMArrayType::verifyConstructionInvariants(Location loc, LLVMType elementType,
283 unsigned numElements) {
284 if (!isValidElementType(elementType))
285 return emitError(loc, "invalid array element type: ") << elementType;
286 return success();
287 }
288
289 //===----------------------------------------------------------------------===//
290 // Function type.
291
isValidArgumentType(LLVMType type)292 bool LLVMFunctionType::isValidArgumentType(LLVMType type) {
293 return !type.isa<LLVMVoidType, LLVMFunctionType>();
294 }
295
isValidResultType(LLVMType type)296 bool LLVMFunctionType::isValidResultType(LLVMType type) {
297 return !type.isa<LLVMFunctionType, LLVMMetadataType, LLVMLabelType>();
298 }
299
get(LLVMType result,ArrayRef<LLVMType> arguments,bool isVarArg)300 LLVMFunctionType LLVMFunctionType::get(LLVMType result,
301 ArrayRef<LLVMType> arguments,
302 bool isVarArg) {
303 assert(result && "expected non-null result");
304 return Base::get(result.getContext(), result, arguments, isVarArg);
305 }
306
getChecked(Location loc,LLVMType result,ArrayRef<LLVMType> arguments,bool isVarArg)307 LLVMFunctionType LLVMFunctionType::getChecked(Location loc, LLVMType result,
308 ArrayRef<LLVMType> arguments,
309 bool isVarArg) {
310 assert(result && "expected non-null result");
311 return Base::getChecked(loc, result, arguments, isVarArg);
312 }
313
getReturnType()314 LLVMType LLVMFunctionType::getReturnType() {
315 return getImpl()->getReturnType();
316 }
317
getNumParams()318 unsigned LLVMFunctionType::getNumParams() {
319 return getImpl()->getArgumentTypes().size();
320 }
321
getParamType(unsigned i)322 LLVMType LLVMFunctionType::getParamType(unsigned i) {
323 return getImpl()->getArgumentTypes()[i];
324 }
325
isVarArg()326 bool LLVMFunctionType::isVarArg() { return getImpl()->isVariadic(); }
327
getParams()328 ArrayRef<LLVMType> LLVMFunctionType::getParams() {
329 return getImpl()->getArgumentTypes();
330 }
331
verifyConstructionInvariants(Location loc,LLVMType result,ArrayRef<LLVMType> arguments,bool)332 LogicalResult LLVMFunctionType::verifyConstructionInvariants(
333 Location loc, LLVMType result, ArrayRef<LLVMType> arguments, bool) {
334 if (!isValidResultType(result))
335 return emitError(loc, "invalid function result type: ") << result;
336
337 for (LLVMType arg : arguments)
338 if (!isValidArgumentType(arg))
339 return emitError(loc, "invalid function argument type: ") << arg;
340
341 return success();
342 }
343
344 //===----------------------------------------------------------------------===//
345 // Integer type.
346
get(MLIRContext * ctx,unsigned bitwidth)347 LLVMIntegerType LLVMIntegerType::get(MLIRContext *ctx, unsigned bitwidth) {
348 return Base::get(ctx, bitwidth);
349 }
350
getChecked(Location loc,unsigned bitwidth)351 LLVMIntegerType LLVMIntegerType::getChecked(Location loc, unsigned bitwidth) {
352 return Base::getChecked(loc, bitwidth);
353 }
354
getBitWidth()355 unsigned LLVMIntegerType::getBitWidth() { return getImpl()->bitwidth; }
356
verifyConstructionInvariants(Location loc,unsigned bitwidth)357 LogicalResult LLVMIntegerType::verifyConstructionInvariants(Location loc,
358 unsigned bitwidth) {
359 constexpr int maxSupportedBitwidth = (1 << 24);
360 if (bitwidth >= maxSupportedBitwidth)
361 return emitError(loc, "integer type too wide");
362 return success();
363 }
364
365 //===----------------------------------------------------------------------===//
366 // Pointer type.
367
isValidElementType(LLVMType type)368 bool LLVMPointerType::isValidElementType(LLVMType type) {
369 return !type.isa<LLVMVoidType, LLVMTokenType, LLVMMetadataType,
370 LLVMLabelType>();
371 }
372
get(LLVMType pointee,unsigned addressSpace)373 LLVMPointerType LLVMPointerType::get(LLVMType pointee, unsigned addressSpace) {
374 assert(pointee && "expected non-null subtype");
375 return Base::get(pointee.getContext(), pointee, addressSpace);
376 }
377
getChecked(Location loc,LLVMType pointee,unsigned addressSpace)378 LLVMPointerType LLVMPointerType::getChecked(Location loc, LLVMType pointee,
379 unsigned addressSpace) {
380 return Base::getChecked(loc, pointee, addressSpace);
381 }
382
getElementType()383 LLVMType LLVMPointerType::getElementType() { return getImpl()->pointeeType; }
384
getAddressSpace()385 unsigned LLVMPointerType::getAddressSpace() { return getImpl()->addressSpace; }
386
verifyConstructionInvariants(Location loc,LLVMType pointee,unsigned)387 LogicalResult LLVMPointerType::verifyConstructionInvariants(Location loc,
388 LLVMType pointee,
389 unsigned) {
390 if (!isValidElementType(pointee))
391 return emitError(loc, "invalid pointer element type: ") << pointee;
392 return success();
393 }
394
395 //===----------------------------------------------------------------------===//
396 // Struct type.
397
isValidElementType(LLVMType type)398 bool LLVMStructType::isValidElementType(LLVMType type) {
399 return !type.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
400 LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>();
401 }
402
getIdentified(MLIRContext * context,StringRef name)403 LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
404 StringRef name) {
405 return Base::get(context, name, /*opaque=*/false);
406 }
407
getIdentifiedChecked(Location loc,StringRef name)408 LLVMStructType LLVMStructType::getIdentifiedChecked(Location loc,
409 StringRef name) {
410 return Base::getChecked(loc, name, /*opaque=*/false);
411 }
412
getLiteral(MLIRContext * context,ArrayRef<LLVMType> types,bool isPacked)413 LLVMStructType LLVMStructType::getLiteral(MLIRContext *context,
414 ArrayRef<LLVMType> types,
415 bool isPacked) {
416 return Base::get(context, types, isPacked);
417 }
418
getLiteralChecked(Location loc,ArrayRef<LLVMType> types,bool isPacked)419 LLVMStructType LLVMStructType::getLiteralChecked(Location loc,
420 ArrayRef<LLVMType> types,
421 bool isPacked) {
422 return Base::getChecked(loc, types, isPacked);
423 }
424
getOpaque(StringRef name,MLIRContext * context)425 LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) {
426 return Base::get(context, name, /*opaque=*/true);
427 }
428
getOpaqueChecked(Location loc,StringRef name)429 LLVMStructType LLVMStructType::getOpaqueChecked(Location loc, StringRef name) {
430 return Base::getChecked(loc, name, /*opaque=*/true);
431 }
432
setBody(ArrayRef<LLVMType> types,bool isPacked)433 LogicalResult LLVMStructType::setBody(ArrayRef<LLVMType> types, bool isPacked) {
434 assert(isIdentified() && "can only set bodies of identified structs");
435 assert(llvm::all_of(types, LLVMStructType::isValidElementType) &&
436 "expected valid body types");
437 return Base::mutate(types, isPacked);
438 }
439
isPacked()440 bool LLVMStructType::isPacked() { return getImpl()->isPacked(); }
isIdentified()441 bool LLVMStructType::isIdentified() { return getImpl()->isIdentified(); }
isOpaque()442 bool LLVMStructType::isOpaque() {
443 return getImpl()->isIdentified() &&
444 (getImpl()->isOpaque() || !getImpl()->isInitialized());
445 }
isInitialized()446 bool LLVMStructType::isInitialized() { return getImpl()->isInitialized(); }
getName()447 StringRef LLVMStructType::getName() { return getImpl()->getIdentifier(); }
getBody()448 ArrayRef<LLVMType> LLVMStructType::getBody() {
449 return isIdentified() ? getImpl()->getIdentifiedStructBody()
450 : getImpl()->getTypeList();
451 }
452
verifyConstructionInvariants(Location,StringRef,bool)453 LogicalResult LLVMStructType::verifyConstructionInvariants(Location, StringRef,
454 bool) {
455 return success();
456 }
457
458 LogicalResult
verifyConstructionInvariants(Location loc,ArrayRef<LLVMType> types,bool)459 LLVMStructType::verifyConstructionInvariants(Location loc,
460 ArrayRef<LLVMType> types, bool) {
461 for (LLVMType t : types)
462 if (!isValidElementType(t))
463 return emitError(loc, "invalid LLVM structure element type: ") << t;
464
465 return success();
466 }
467
468 //===----------------------------------------------------------------------===//
469 // Vector types.
470
isValidElementType(LLVMType type)471 bool LLVMVectorType::isValidElementType(LLVMType type) {
472 return type.isa<LLVMIntegerType, LLVMPointerType>() ||
473 type.isFloatingPointTy();
474 }
475
476 /// Support type casting functionality.
classof(Type type)477 bool LLVMVectorType::classof(Type type) {
478 return type.isa<LLVMFixedVectorType, LLVMScalableVectorType>();
479 }
480
getElementType()481 LLVMType LLVMVectorType::getElementType() {
482 // Both derived classes share the implementation type.
483 return static_cast<detail::LLVMTypeAndSizeStorage *>(impl)->elementType;
484 }
485
getElementCount()486 llvm::ElementCount LLVMVectorType::getElementCount() {
487 // Both derived classes share the implementation type.
488 return llvm::ElementCount::get(
489 static_cast<detail::LLVMTypeAndSizeStorage *>(impl)->numElements,
490 isa<LLVMScalableVectorType>());
491 }
492
493 /// Verifies that the type about to be constructed is well-formed.
494 LogicalResult
verifyConstructionInvariants(Location loc,LLVMType elementType,unsigned numElements)495 LLVMVectorType::verifyConstructionInvariants(Location loc, LLVMType elementType,
496 unsigned numElements) {
497 if (numElements == 0)
498 return emitError(loc, "the number of vector elements must be positive");
499
500 if (!isValidElementType(elementType))
501 return emitError(loc, "invalid vector element type");
502
503 return success();
504 }
505
get(LLVMType elementType,unsigned numElements)506 LLVMFixedVectorType LLVMFixedVectorType::get(LLVMType elementType,
507 unsigned numElements) {
508 assert(elementType && "expected non-null subtype");
509 return Base::get(elementType.getContext(), elementType, numElements);
510 }
511
getChecked(Location loc,LLVMType elementType,unsigned numElements)512 LLVMFixedVectorType LLVMFixedVectorType::getChecked(Location loc,
513 LLVMType elementType,
514 unsigned numElements) {
515 assert(elementType && "expected non-null subtype");
516 return Base::getChecked(loc, elementType, numElements);
517 }
518
getNumElements()519 unsigned LLVMFixedVectorType::getNumElements() {
520 return getImpl()->numElements;
521 }
522
get(LLVMType elementType,unsigned minNumElements)523 LLVMScalableVectorType LLVMScalableVectorType::get(LLVMType elementType,
524 unsigned minNumElements) {
525 assert(elementType && "expected non-null subtype");
526 return Base::get(elementType.getContext(), elementType, minNumElements);
527 }
528
529 LLVMScalableVectorType
getChecked(Location loc,LLVMType elementType,unsigned minNumElements)530 LLVMScalableVectorType::getChecked(Location loc, LLVMType elementType,
531 unsigned minNumElements) {
532 assert(elementType && "expected non-null subtype");
533 return Base::getChecked(loc, elementType, minNumElements);
534 }
535
getMinNumElements()536 unsigned LLVMScalableVectorType::getMinNumElements() {
537 return getImpl()->numElements;
538 }
539