1 //===- TypeTranslation.cpp - type translation between MLIR LLVM & LLVM IR -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Target/LLVMIR/TypeTranslation.h"
10 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
11 #include "mlir/IR/MLIRContext.h"
12
13 #include "llvm/ADT/TypeSwitch.h"
14 #include "llvm/IR/DataLayout.h"
15 #include "llvm/IR/DerivedTypes.h"
16 #include "llvm/IR/Type.h"
17
18 using namespace mlir;
19
20 namespace mlir {
21 namespace LLVM {
22 namespace detail {
23 /// Support for translating MLIR LLVM dialect types to LLVM IR.
24 class TypeToLLVMIRTranslatorImpl {
25 public:
26 /// Constructs a class creating types in the given LLVM context.
TypeToLLVMIRTranslatorImpl(llvm::LLVMContext & context)27 TypeToLLVMIRTranslatorImpl(llvm::LLVMContext &context) : context(context) {}
28
29 /// Translates a single type.
translateType(LLVM::LLVMType type)30 llvm::Type *translateType(LLVM::LLVMType type) {
31 // If the conversion is already known, just return it.
32 if (knownTranslations.count(type))
33 return knownTranslations.lookup(type);
34
35 // Dispatch to an appropriate function.
36 llvm::Type *translated =
37 llvm::TypeSwitch<LLVM::LLVMType, llvm::Type *>(type)
38 .Case([this](LLVM::LLVMVoidType) {
39 return llvm::Type::getVoidTy(context);
40 })
41 .Case([this](LLVM::LLVMHalfType) {
42 return llvm::Type::getHalfTy(context);
43 })
44 .Case([this](LLVM::LLVMBFloatType) {
45 return llvm::Type::getBFloatTy(context);
46 })
47 .Case([this](LLVM::LLVMFloatType) {
48 return llvm::Type::getFloatTy(context);
49 })
50 .Case([this](LLVM::LLVMDoubleType) {
51 return llvm::Type::getDoubleTy(context);
52 })
53 .Case([this](LLVM::LLVMFP128Type) {
54 return llvm::Type::getFP128Ty(context);
55 })
56 .Case([this](LLVM::LLVMX86FP80Type) {
57 return llvm::Type::getX86_FP80Ty(context);
58 })
59 .Case([this](LLVM::LLVMPPCFP128Type) {
60 return llvm::Type::getPPC_FP128Ty(context);
61 })
62 .Case([this](LLVM::LLVMX86MMXType) {
63 return llvm::Type::getX86_MMXTy(context);
64 })
65 .Case([this](LLVM::LLVMTokenType) {
66 return llvm::Type::getTokenTy(context);
67 })
68 .Case([this](LLVM::LLVMLabelType) {
69 return llvm::Type::getLabelTy(context);
70 })
71 .Case([this](LLVM::LLVMMetadataType) {
72 return llvm::Type::getMetadataTy(context);
73 })
74 .Case<LLVM::LLVMArrayType, LLVM::LLVMIntegerType,
75 LLVM::LLVMFunctionType, LLVM::LLVMPointerType,
76 LLVM::LLVMStructType, LLVM::LLVMFixedVectorType,
77 LLVM::LLVMScalableVectorType>(
78 [this](auto type) { return this->translate(type); })
79 .Default([](LLVM::LLVMType t) -> llvm::Type * {
80 llvm_unreachable("unknown LLVM dialect type");
81 });
82
83 // Cache the result of the conversion and return.
84 knownTranslations.try_emplace(type, translated);
85 return translated;
86 }
87
88 private:
89 /// Translates the given array type.
translate(LLVM::LLVMArrayType type)90 llvm::Type *translate(LLVM::LLVMArrayType type) {
91 return llvm::ArrayType::get(translateType(type.getElementType()),
92 type.getNumElements());
93 }
94
95 /// Translates the given function type.
translate(LLVM::LLVMFunctionType type)96 llvm::Type *translate(LLVM::LLVMFunctionType type) {
97 SmallVector<llvm::Type *, 8> paramTypes;
98 translateTypes(type.getParams(), paramTypes);
99 return llvm::FunctionType::get(translateType(type.getReturnType()),
100 paramTypes, type.isVarArg());
101 }
102
103 /// Translates the given integer type.
translate(LLVM::LLVMIntegerType type)104 llvm::Type *translate(LLVM::LLVMIntegerType type) {
105 return llvm::IntegerType::get(context, type.getBitWidth());
106 }
107
108 /// Translates the given pointer type.
translate(LLVM::LLVMPointerType type)109 llvm::Type *translate(LLVM::LLVMPointerType type) {
110 return llvm::PointerType::get(translateType(type.getElementType()),
111 type.getAddressSpace());
112 }
113
114 /// Translates the given structure type, supports both identified and literal
115 /// structs. This will _create_ a new identified structure every time, use
116 /// `convertType` if a structure with the same name must be looked up instead.
translate(LLVM::LLVMStructType type)117 llvm::Type *translate(LLVM::LLVMStructType type) {
118 SmallVector<llvm::Type *, 8> subtypes;
119 if (!type.isIdentified()) {
120 translateTypes(type.getBody(), subtypes);
121 return llvm::StructType::get(context, subtypes, type.isPacked());
122 }
123
124 llvm::StructType *structType =
125 llvm::StructType::create(context, type.getName());
126 // Mark the type we just created as known so that recursive calls can pick
127 // it up and use directly.
128 knownTranslations.try_emplace(type, structType);
129 if (type.isOpaque())
130 return structType;
131
132 translateTypes(type.getBody(), subtypes);
133 structType->setBody(subtypes, type.isPacked());
134 return structType;
135 }
136
137 /// Translates the given fixed-vector type.
translate(LLVM::LLVMFixedVectorType type)138 llvm::Type *translate(LLVM::LLVMFixedVectorType type) {
139 return llvm::FixedVectorType::get(translateType(type.getElementType()),
140 type.getNumElements());
141 }
142
143 /// Translates the given scalable-vector type.
translate(LLVM::LLVMScalableVectorType type)144 llvm::Type *translate(LLVM::LLVMScalableVectorType type) {
145 return llvm::ScalableVectorType::get(translateType(type.getElementType()),
146 type.getMinNumElements());
147 }
148
149 /// Translates a list of types.
translateTypes(ArrayRef<LLVM::LLVMType> types,SmallVectorImpl<llvm::Type * > & result)150 void translateTypes(ArrayRef<LLVM::LLVMType> types,
151 SmallVectorImpl<llvm::Type *> &result) {
152 result.reserve(result.size() + types.size());
153 for (auto type : types)
154 result.push_back(translateType(type));
155 }
156
157 /// Reference to the context in which the LLVM IR types are created.
158 llvm::LLVMContext &context;
159
160 /// Map of known translation. This serves a double purpose: caches translation
161 /// results to avoid repeated recursive calls and makes sure identified
162 /// structs with the same name (that is, equal) are resolved to an existing
163 /// type instead of creating a new type.
164 llvm::DenseMap<LLVM::LLVMType, llvm::Type *> knownTranslations;
165 };
166 } // end namespace detail
167 } // end namespace LLVM
168 } // end namespace mlir
169
TypeToLLVMIRTranslator(llvm::LLVMContext & context)170 LLVM::TypeToLLVMIRTranslator::TypeToLLVMIRTranslator(llvm::LLVMContext &context)
171 : impl(new detail::TypeToLLVMIRTranslatorImpl(context)) {}
172
~TypeToLLVMIRTranslator()173 LLVM::TypeToLLVMIRTranslator::~TypeToLLVMIRTranslator() {}
174
translateType(LLVM::LLVMType type)175 llvm::Type *LLVM::TypeToLLVMIRTranslator::translateType(LLVM::LLVMType type) {
176 return impl->translateType(type);
177 }
178
getPreferredAlignment(LLVM::LLVMType type,const llvm::DataLayout & layout)179 unsigned LLVM::TypeToLLVMIRTranslator::getPreferredAlignment(
180 LLVM::LLVMType type, const llvm::DataLayout &layout) {
181 return layout.getPrefTypeAlignment(translateType(type));
182 }
183
184 namespace mlir {
185 namespace LLVM {
186 namespace detail {
187 /// Support for translating LLVM IR types to MLIR LLVM dialect types.
188 class TypeFromLLVMIRTranslatorImpl {
189 public:
190 /// Constructs a class creating types in the given MLIR context.
TypeFromLLVMIRTranslatorImpl(MLIRContext & context)191 TypeFromLLVMIRTranslatorImpl(MLIRContext &context) : context(context) {}
192
193 /// Translates the given type.
translateType(llvm::Type * type)194 LLVM::LLVMType translateType(llvm::Type *type) {
195 if (knownTranslations.count(type))
196 return knownTranslations.lookup(type);
197
198 LLVM::LLVMType translated =
199 llvm::TypeSwitch<llvm::Type *, LLVM::LLVMType>(type)
200 .Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType,
201 llvm::PointerType, llvm::StructType, llvm::FixedVectorType,
202 llvm::ScalableVectorType>(
203 [this](auto *type) { return this->translate(type); })
204 .Default([this](llvm::Type *type) {
205 return translatePrimitiveType(type);
206 });
207 knownTranslations.try_emplace(type, translated);
208 return translated;
209 }
210
211 private:
212 /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature,
213 /// type.
translatePrimitiveType(llvm::Type * type)214 LLVM::LLVMType translatePrimitiveType(llvm::Type *type) {
215 if (type->isVoidTy())
216 return LLVM::LLVMVoidType::get(&context);
217 if (type->isHalfTy())
218 return LLVM::LLVMHalfType::get(&context);
219 if (type->isBFloatTy())
220 return LLVM::LLVMBFloatType::get(&context);
221 if (type->isFloatTy())
222 return LLVM::LLVMFloatType::get(&context);
223 if (type->isDoubleTy())
224 return LLVM::LLVMDoubleType::get(&context);
225 if (type->isFP128Ty())
226 return LLVM::LLVMFP128Type::get(&context);
227 if (type->isX86_FP80Ty())
228 return LLVM::LLVMX86FP80Type::get(&context);
229 if (type->isPPC_FP128Ty())
230 return LLVM::LLVMPPCFP128Type::get(&context);
231 if (type->isX86_MMXTy())
232 return LLVM::LLVMX86MMXType::get(&context);
233 if (type->isLabelTy())
234 return LLVM::LLVMLabelType::get(&context);
235 if (type->isMetadataTy())
236 return LLVM::LLVMMetadataType::get(&context);
237 llvm_unreachable("not a primitive type");
238 }
239
240 /// Translates the given array type.
translate(llvm::ArrayType * type)241 LLVM::LLVMType translate(llvm::ArrayType *type) {
242 return LLVM::LLVMArrayType::get(translateType(type->getElementType()),
243 type->getNumElements());
244 }
245
246 /// Translates the given function type.
translate(llvm::FunctionType * type)247 LLVM::LLVMType translate(llvm::FunctionType *type) {
248 SmallVector<LLVM::LLVMType, 8> paramTypes;
249 translateTypes(type->params(), paramTypes);
250 return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()),
251 paramTypes, type->isVarArg());
252 }
253
254 /// Translates the given integer type.
translate(llvm::IntegerType * type)255 LLVM::LLVMType translate(llvm::IntegerType *type) {
256 return LLVM::LLVMIntegerType::get(&context, type->getBitWidth());
257 }
258
259 /// Translates the given pointer type.
translate(llvm::PointerType * type)260 LLVM::LLVMType translate(llvm::PointerType *type) {
261 return LLVM::LLVMPointerType::get(translateType(type->getElementType()),
262 type->getAddressSpace());
263 }
264
265 /// Translates the given structure type.
translate(llvm::StructType * type)266 LLVM::LLVMType translate(llvm::StructType *type) {
267 SmallVector<LLVM::LLVMType, 8> subtypes;
268 if (type->isLiteral()) {
269 translateTypes(type->subtypes(), subtypes);
270 return LLVM::LLVMStructType::getLiteral(&context, subtypes,
271 type->isPacked());
272 }
273
274 if (type->isOpaque())
275 return LLVM::LLVMStructType::getOpaque(type->getName(), &context);
276
277 LLVM::LLVMStructType translated =
278 LLVM::LLVMStructType::getIdentified(&context, type->getName());
279 knownTranslations.try_emplace(type, translated);
280 translateTypes(type->subtypes(), subtypes);
281 LogicalResult bodySet = translated.setBody(subtypes, type->isPacked());
282 assert(succeeded(bodySet) &&
283 "could not set the body of an identified struct");
284 (void)bodySet;
285 return translated;
286 }
287
288 /// Translates the given fixed-vector type.
translate(llvm::FixedVectorType * type)289 LLVM::LLVMType translate(llvm::FixedVectorType *type) {
290 return LLVM::LLVMFixedVectorType::get(translateType(type->getElementType()),
291 type->getNumElements());
292 }
293
294 /// Translates the given scalable-vector type.
translate(llvm::ScalableVectorType * type)295 LLVM::LLVMType translate(llvm::ScalableVectorType *type) {
296 return LLVM::LLVMScalableVectorType::get(
297 translateType(type->getElementType()), type->getMinNumElements());
298 }
299
300 /// Translates a list of types.
translateTypes(ArrayRef<llvm::Type * > types,SmallVectorImpl<LLVM::LLVMType> & result)301 void translateTypes(ArrayRef<llvm::Type *> types,
302 SmallVectorImpl<LLVM::LLVMType> &result) {
303 result.reserve(result.size() + types.size());
304 for (llvm::Type *type : types)
305 result.push_back(translateType(type));
306 }
307
308 /// Map of known translations. Serves as a cache and as recursion stopper for
309 /// translating recursive structs.
310 llvm::DenseMap<llvm::Type *, LLVM::LLVMType> knownTranslations;
311
312 /// The context in which MLIR types are created.
313 MLIRContext &context;
314 };
315 } // end namespace detail
316 } // end namespace LLVM
317 } // end namespace mlir
318
TypeFromLLVMIRTranslator(MLIRContext & context)319 LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator(MLIRContext &context)
320 : impl(new detail::TypeFromLLVMIRTranslatorImpl(context)) {}
321
~TypeFromLLVMIRTranslator()322 LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() {}
323
translateType(llvm::Type * type)324 LLVM::LLVMType LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) {
325 return impl->translateType(type);
326 }
327