1 //===- BuiltinTypes.cpp - C Interface to MLIR Builtin Types ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir-c/BuiltinTypes.h"
10 #include "mlir-c/AffineMap.h"
11 #include "mlir-c/IR.h"
12 #include "mlir/CAPI/AffineMap.h"
13 #include "mlir/CAPI/IR.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Types.h"
17
18 using namespace mlir;
19
20 //===----------------------------------------------------------------------===//
21 // Integer types.
22 //===----------------------------------------------------------------------===//
23
mlirTypeIsAInteger(MlirType type)24 bool mlirTypeIsAInteger(MlirType type) {
25 return unwrap(type).isa<IntegerType>();
26 }
27
mlirIntegerTypeGet(MlirContext ctx,unsigned bitwidth)28 MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) {
29 return wrap(IntegerType::get(bitwidth, unwrap(ctx)));
30 }
31
mlirIntegerTypeSignedGet(MlirContext ctx,unsigned bitwidth)32 MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth) {
33 return wrap(IntegerType::get(bitwidth, IntegerType::Signed, unwrap(ctx)));
34 }
35
mlirIntegerTypeUnsignedGet(MlirContext ctx,unsigned bitwidth)36 MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth) {
37 return wrap(IntegerType::get(bitwidth, IntegerType::Unsigned, unwrap(ctx)));
38 }
39
mlirIntegerTypeGetWidth(MlirType type)40 unsigned mlirIntegerTypeGetWidth(MlirType type) {
41 return unwrap(type).cast<IntegerType>().getWidth();
42 }
43
mlirIntegerTypeIsSignless(MlirType type)44 bool mlirIntegerTypeIsSignless(MlirType type) {
45 return unwrap(type).cast<IntegerType>().isSignless();
46 }
47
mlirIntegerTypeIsSigned(MlirType type)48 bool mlirIntegerTypeIsSigned(MlirType type) {
49 return unwrap(type).cast<IntegerType>().isSigned();
50 }
51
mlirIntegerTypeIsUnsigned(MlirType type)52 bool mlirIntegerTypeIsUnsigned(MlirType type) {
53 return unwrap(type).cast<IntegerType>().isUnsigned();
54 }
55
56 //===----------------------------------------------------------------------===//
57 // Index type.
58 //===----------------------------------------------------------------------===//
59
mlirTypeIsAIndex(MlirType type)60 bool mlirTypeIsAIndex(MlirType type) { return unwrap(type).isa<IndexType>(); }
61
mlirIndexTypeGet(MlirContext ctx)62 MlirType mlirIndexTypeGet(MlirContext ctx) {
63 return wrap(IndexType::get(unwrap(ctx)));
64 }
65
66 //===----------------------------------------------------------------------===//
67 // Floating-point types.
68 //===----------------------------------------------------------------------===//
69
mlirTypeIsABF16(MlirType type)70 bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
71
mlirBF16TypeGet(MlirContext ctx)72 MlirType mlirBF16TypeGet(MlirContext ctx) {
73 return wrap(FloatType::getBF16(unwrap(ctx)));
74 }
75
mlirTypeIsAF16(MlirType type)76 bool mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); }
77
mlirF16TypeGet(MlirContext ctx)78 MlirType mlirF16TypeGet(MlirContext ctx) {
79 return wrap(FloatType::getF16(unwrap(ctx)));
80 }
81
mlirTypeIsAF32(MlirType type)82 bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); }
83
mlirF32TypeGet(MlirContext ctx)84 MlirType mlirF32TypeGet(MlirContext ctx) {
85 return wrap(FloatType::getF32(unwrap(ctx)));
86 }
87
mlirTypeIsAF64(MlirType type)88 bool mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); }
89
mlirF64TypeGet(MlirContext ctx)90 MlirType mlirF64TypeGet(MlirContext ctx) {
91 return wrap(FloatType::getF64(unwrap(ctx)));
92 }
93
94 //===----------------------------------------------------------------------===//
95 // None type.
96 //===----------------------------------------------------------------------===//
97
mlirTypeIsANone(MlirType type)98 bool mlirTypeIsANone(MlirType type) { return unwrap(type).isa<NoneType>(); }
99
mlirNoneTypeGet(MlirContext ctx)100 MlirType mlirNoneTypeGet(MlirContext ctx) {
101 return wrap(NoneType::get(unwrap(ctx)));
102 }
103
104 //===----------------------------------------------------------------------===//
105 // Complex type.
106 //===----------------------------------------------------------------------===//
107
mlirTypeIsAComplex(MlirType type)108 bool mlirTypeIsAComplex(MlirType type) {
109 return unwrap(type).isa<ComplexType>();
110 }
111
mlirComplexTypeGet(MlirType elementType)112 MlirType mlirComplexTypeGet(MlirType elementType) {
113 return wrap(ComplexType::get(unwrap(elementType)));
114 }
115
mlirComplexTypeGetElementType(MlirType type)116 MlirType mlirComplexTypeGetElementType(MlirType type) {
117 return wrap(unwrap(type).cast<ComplexType>().getElementType());
118 }
119
120 //===----------------------------------------------------------------------===//
121 // Shaped type.
122 //===----------------------------------------------------------------------===//
123
mlirTypeIsAShaped(MlirType type)124 bool mlirTypeIsAShaped(MlirType type) { return unwrap(type).isa<ShapedType>(); }
125
mlirShapedTypeGetElementType(MlirType type)126 MlirType mlirShapedTypeGetElementType(MlirType type) {
127 return wrap(unwrap(type).cast<ShapedType>().getElementType());
128 }
129
mlirShapedTypeHasRank(MlirType type)130 bool mlirShapedTypeHasRank(MlirType type) {
131 return unwrap(type).cast<ShapedType>().hasRank();
132 }
133
mlirShapedTypeGetRank(MlirType type)134 int64_t mlirShapedTypeGetRank(MlirType type) {
135 return unwrap(type).cast<ShapedType>().getRank();
136 }
137
mlirShapedTypeHasStaticShape(MlirType type)138 bool mlirShapedTypeHasStaticShape(MlirType type) {
139 return unwrap(type).cast<ShapedType>().hasStaticShape();
140 }
141
mlirShapedTypeIsDynamicDim(MlirType type,intptr_t dim)142 bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) {
143 return unwrap(type).cast<ShapedType>().isDynamicDim(
144 static_cast<unsigned>(dim));
145 }
146
mlirShapedTypeGetDimSize(MlirType type,intptr_t dim)147 int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) {
148 return unwrap(type).cast<ShapedType>().getDimSize(static_cast<unsigned>(dim));
149 }
150
mlirShapedTypeIsDynamicSize(int64_t size)151 bool mlirShapedTypeIsDynamicSize(int64_t size) {
152 return ShapedType::isDynamic(size);
153 }
154
mlirShapedTypeIsDynamicStrideOrOffset(int64_t val)155 bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) {
156 return ShapedType::isDynamicStrideOrOffset(val);
157 }
158
159 //===----------------------------------------------------------------------===//
160 // Vector type.
161 //===----------------------------------------------------------------------===//
162
mlirTypeIsAVector(MlirType type)163 bool mlirTypeIsAVector(MlirType type) { return unwrap(type).isa<VectorType>(); }
164
mlirVectorTypeGet(intptr_t rank,const int64_t * shape,MlirType elementType)165 MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
166 MlirType elementType) {
167 return wrap(
168 VectorType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
169 unwrap(elementType)));
170 }
171
mlirVectorTypeGetChecked(intptr_t rank,const int64_t * shape,MlirType elementType,MlirLocation loc)172 MlirType mlirVectorTypeGetChecked(intptr_t rank, const int64_t *shape,
173 MlirType elementType, MlirLocation loc) {
174 return wrap(VectorType::getChecked(
175 llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
176 unwrap(loc)));
177 }
178
179 //===----------------------------------------------------------------------===//
180 // Ranked / Unranked tensor type.
181 //===----------------------------------------------------------------------===//
182
mlirTypeIsATensor(MlirType type)183 bool mlirTypeIsATensor(MlirType type) { return unwrap(type).isa<TensorType>(); }
184
mlirTypeIsARankedTensor(MlirType type)185 bool mlirTypeIsARankedTensor(MlirType type) {
186 return unwrap(type).isa<RankedTensorType>();
187 }
188
mlirTypeIsAUnrankedTensor(MlirType type)189 bool mlirTypeIsAUnrankedTensor(MlirType type) {
190 return unwrap(type).isa<UnrankedTensorType>();
191 }
192
mlirRankedTensorTypeGet(intptr_t rank,const int64_t * shape,MlirType elementType)193 MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape,
194 MlirType elementType) {
195 return wrap(RankedTensorType::get(
196 llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
197 unwrap(elementType)));
198 }
199
mlirRankedTensorTypeGetChecked(intptr_t rank,const int64_t * shape,MlirType elementType,MlirLocation loc)200 MlirType mlirRankedTensorTypeGetChecked(intptr_t rank, const int64_t *shape,
201 MlirType elementType,
202 MlirLocation loc) {
203 return wrap(RankedTensorType::getChecked(
204 llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
205 unwrap(loc)));
206 }
207
mlirUnrankedTensorTypeGet(MlirType elementType)208 MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
209 return wrap(UnrankedTensorType::get(unwrap(elementType)));
210 }
211
mlirUnrankedTensorTypeGetChecked(MlirType elementType,MlirLocation loc)212 MlirType mlirUnrankedTensorTypeGetChecked(MlirType elementType,
213 MlirLocation loc) {
214 return wrap(UnrankedTensorType::getChecked(unwrap(elementType), unwrap(loc)));
215 }
216
217 //===----------------------------------------------------------------------===//
218 // Ranked / Unranked MemRef type.
219 //===----------------------------------------------------------------------===//
220
mlirTypeIsAMemRef(MlirType type)221 bool mlirTypeIsAMemRef(MlirType type) { return unwrap(type).isa<MemRefType>(); }
222
mlirMemRefTypeGet(MlirType elementType,intptr_t rank,const int64_t * shape,intptr_t numMaps,MlirAffineMap const * affineMaps,unsigned memorySpace)223 MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank,
224 const int64_t *shape, intptr_t numMaps,
225 MlirAffineMap const *affineMaps,
226 unsigned memorySpace) {
227 SmallVector<AffineMap, 1> maps;
228 (void)unwrapList(numMaps, affineMaps, maps);
229 return wrap(
230 MemRefType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
231 unwrap(elementType), maps, memorySpace));
232 }
233
mlirMemRefTypeContiguousGet(MlirType elementType,intptr_t rank,const int64_t * shape,unsigned memorySpace)234 MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
235 const int64_t *shape,
236 unsigned memorySpace) {
237 return wrap(
238 MemRefType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
239 unwrap(elementType), llvm::None, memorySpace));
240 }
241
mlirMemRefTypeContiguousGetChecked(MlirType elementType,intptr_t rank,const int64_t * shape,unsigned memorySpace,MlirLocation loc)242 MlirType mlirMemRefTypeContiguousGetChecked(MlirType elementType, intptr_t rank,
243 const int64_t *shape,
244 unsigned memorySpace,
245 MlirLocation loc) {
246 return wrap(MemRefType::getChecked(
247 llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
248 llvm::None, memorySpace, unwrap(loc)));
249 }
250
mlirMemRefTypeGetNumAffineMaps(MlirType type)251 intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) {
252 return static_cast<intptr_t>(
253 unwrap(type).cast<MemRefType>().getAffineMaps().size());
254 }
255
mlirMemRefTypeGetAffineMap(MlirType type,intptr_t pos)256 MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type, intptr_t pos) {
257 return wrap(unwrap(type).cast<MemRefType>().getAffineMaps()[pos]);
258 }
259
mlirMemRefTypeGetMemorySpace(MlirType type)260 unsigned mlirMemRefTypeGetMemorySpace(MlirType type) {
261 return unwrap(type).cast<MemRefType>().getMemorySpace();
262 }
263
mlirTypeIsAUnrankedMemRef(MlirType type)264 bool mlirTypeIsAUnrankedMemRef(MlirType type) {
265 return unwrap(type).isa<UnrankedMemRefType>();
266 }
267
mlirUnrankedMemRefTypeGet(MlirType elementType,unsigned memorySpace)268 MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace) {
269 return wrap(UnrankedMemRefType::get(unwrap(elementType), memorySpace));
270 }
271
mlirUnrankedMemRefTypeGetChecked(MlirType elementType,unsigned memorySpace,MlirLocation loc)272 MlirType mlirUnrankedMemRefTypeGetChecked(MlirType elementType,
273 unsigned memorySpace,
274 MlirLocation loc) {
275 return wrap(UnrankedMemRefType::getChecked(unwrap(elementType), memorySpace,
276 unwrap(loc)));
277 }
278
mlirUnrankedMemrefGetMemorySpace(MlirType type)279 unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type) {
280 return unwrap(type).cast<UnrankedMemRefType>().getMemorySpace();
281 }
282
283 //===----------------------------------------------------------------------===//
284 // Tuple type.
285 //===----------------------------------------------------------------------===//
286
mlirTypeIsATuple(MlirType type)287 bool mlirTypeIsATuple(MlirType type) { return unwrap(type).isa<TupleType>(); }
288
mlirTupleTypeGet(MlirContext ctx,intptr_t numElements,MlirType const * elements)289 MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements,
290 MlirType const *elements) {
291 SmallVector<Type, 4> types;
292 ArrayRef<Type> typeRef = unwrapList(numElements, elements, types);
293 return wrap(TupleType::get(typeRef, unwrap(ctx)));
294 }
295
mlirTupleTypeGetNumTypes(MlirType type)296 intptr_t mlirTupleTypeGetNumTypes(MlirType type) {
297 return unwrap(type).cast<TupleType>().size();
298 }
299
mlirTupleTypeGetType(MlirType type,intptr_t pos)300 MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) {
301 return wrap(unwrap(type).cast<TupleType>().getType(static_cast<size_t>(pos)));
302 }
303
304 //===----------------------------------------------------------------------===//
305 // Function type.
306 //===----------------------------------------------------------------------===//
307
mlirTypeIsAFunction(MlirType type)308 bool mlirTypeIsAFunction(MlirType type) {
309 return unwrap(type).isa<FunctionType>();
310 }
311
mlirFunctionTypeGet(MlirContext ctx,intptr_t numInputs,MlirType const * inputs,intptr_t numResults,MlirType const * results)312 MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs,
313 MlirType const *inputs, intptr_t numResults,
314 MlirType const *results) {
315 SmallVector<Type, 4> inputsList;
316 SmallVector<Type, 4> resultsList;
317 (void)unwrapList(numInputs, inputs, inputsList);
318 (void)unwrapList(numResults, results, resultsList);
319 return wrap(FunctionType::get(inputsList, resultsList, unwrap(ctx)));
320 }
321
mlirFunctionTypeGetNumInputs(MlirType type)322 intptr_t mlirFunctionTypeGetNumInputs(MlirType type) {
323 return unwrap(type).cast<FunctionType>().getNumInputs();
324 }
325
mlirFunctionTypeGetNumResults(MlirType type)326 intptr_t mlirFunctionTypeGetNumResults(MlirType type) {
327 return unwrap(type).cast<FunctionType>().getNumResults();
328 }
329
mlirFunctionTypeGetInput(MlirType type,intptr_t pos)330 MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos) {
331 assert(pos >= 0 && "pos in array must be positive");
332 return wrap(
333 unwrap(type).cast<FunctionType>().getInput(static_cast<unsigned>(pos)));
334 }
335
mlirFunctionTypeGetResult(MlirType type,intptr_t pos)336 MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos) {
337 assert(pos >= 0 && "pos in array must be positive");
338 return wrap(
339 unwrap(type).cast<FunctionType>().getResult(static_cast<unsigned>(pos)));
340 }
341