1 //===- BuiltinAttributes.cpp - C Interface to MLIR Builtin Attributes -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir-c/BuiltinAttributes.h"
10 #include "mlir/CAPI/AffineMap.h"
11 #include "mlir/CAPI/IR.h"
12 #include "mlir/CAPI/Support.h"
13 #include "mlir/IR/Attributes.h"
14 #include "mlir/IR/BuiltinTypes.h"
15
16 using namespace mlir;
17
18 //===----------------------------------------------------------------------===//
19 // Affine map attribute.
20 //===----------------------------------------------------------------------===//
21
mlirAttributeIsAAffineMap(MlirAttribute attr)22 bool mlirAttributeIsAAffineMap(MlirAttribute attr) {
23 return unwrap(attr).isa<AffineMapAttr>();
24 }
25
mlirAffineMapAttrGet(MlirAffineMap map)26 MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) {
27 return wrap(AffineMapAttr::get(unwrap(map)));
28 }
29
mlirAffineMapAttrGetValue(MlirAttribute attr)30 MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) {
31 return wrap(unwrap(attr).cast<AffineMapAttr>().getValue());
32 }
33
34 //===----------------------------------------------------------------------===//
35 // Array attribute.
36 //===----------------------------------------------------------------------===//
37
mlirAttributeIsAArray(MlirAttribute attr)38 bool mlirAttributeIsAArray(MlirAttribute attr) {
39 return unwrap(attr).isa<ArrayAttr>();
40 }
41
mlirArrayAttrGet(MlirContext ctx,intptr_t numElements,MlirAttribute const * elements)42 MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements,
43 MlirAttribute const *elements) {
44 SmallVector<Attribute, 8> attrs;
45 return wrap(ArrayAttr::get(
46 unwrapList(static_cast<size_t>(numElements), elements, attrs),
47 unwrap(ctx)));
48 }
49
mlirArrayAttrGetNumElements(MlirAttribute attr)50 intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) {
51 return static_cast<intptr_t>(unwrap(attr).cast<ArrayAttr>().size());
52 }
53
mlirArrayAttrGetElement(MlirAttribute attr,intptr_t pos)54 MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) {
55 return wrap(unwrap(attr).cast<ArrayAttr>().getValue()[pos]);
56 }
57
58 //===----------------------------------------------------------------------===//
59 // Dictionary attribute.
60 //===----------------------------------------------------------------------===//
61
mlirAttributeIsADictionary(MlirAttribute attr)62 bool mlirAttributeIsADictionary(MlirAttribute attr) {
63 return unwrap(attr).isa<DictionaryAttr>();
64 }
65
mlirDictionaryAttrGet(MlirContext ctx,intptr_t numElements,MlirNamedAttribute const * elements)66 MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements,
67 MlirNamedAttribute const *elements) {
68 SmallVector<NamedAttribute, 8> attributes;
69 attributes.reserve(numElements);
70 for (intptr_t i = 0; i < numElements; ++i)
71 attributes.emplace_back(
72 Identifier::get(unwrap(elements[i].name), unwrap(ctx)),
73 unwrap(elements[i].attribute));
74 return wrap(DictionaryAttr::get(attributes, unwrap(ctx)));
75 }
76
mlirDictionaryAttrGetNumElements(MlirAttribute attr)77 intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) {
78 return static_cast<intptr_t>(unwrap(attr).cast<DictionaryAttr>().size());
79 }
80
mlirDictionaryAttrGetElement(MlirAttribute attr,intptr_t pos)81 MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr,
82 intptr_t pos) {
83 NamedAttribute attribute =
84 unwrap(attr).cast<DictionaryAttr>().getValue()[pos];
85 return {wrap(attribute.first.strref()), wrap(attribute.second)};
86 }
87
mlirDictionaryAttrGetElementByName(MlirAttribute attr,MlirStringRef name)88 MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr,
89 MlirStringRef name) {
90 return wrap(unwrap(attr).cast<DictionaryAttr>().get(unwrap(name)));
91 }
92
93 //===----------------------------------------------------------------------===//
94 // Floating point attribute.
95 //===----------------------------------------------------------------------===//
96
mlirAttributeIsAFloat(MlirAttribute attr)97 bool mlirAttributeIsAFloat(MlirAttribute attr) {
98 return unwrap(attr).isa<FloatAttr>();
99 }
100
mlirFloatAttrDoubleGet(MlirContext ctx,MlirType type,double value)101 MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type,
102 double value) {
103 return wrap(FloatAttr::get(unwrap(type), value));
104 }
105
mlirFloatAttrDoubleGetChecked(MlirType type,double value,MlirLocation loc)106 MlirAttribute mlirFloatAttrDoubleGetChecked(MlirType type, double value,
107 MlirLocation loc) {
108 return wrap(FloatAttr::getChecked(unwrap(type), value, unwrap(loc)));
109 }
110
mlirFloatAttrGetValueDouble(MlirAttribute attr)111 double mlirFloatAttrGetValueDouble(MlirAttribute attr) {
112 return unwrap(attr).cast<FloatAttr>().getValueAsDouble();
113 }
114
115 //===----------------------------------------------------------------------===//
116 // Integer attribute.
117 //===----------------------------------------------------------------------===//
118
mlirAttributeIsAInteger(MlirAttribute attr)119 bool mlirAttributeIsAInteger(MlirAttribute attr) {
120 return unwrap(attr).isa<IntegerAttr>();
121 }
122
mlirIntegerAttrGet(MlirType type,int64_t value)123 MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) {
124 return wrap(IntegerAttr::get(unwrap(type), value));
125 }
126
mlirIntegerAttrGetValueInt(MlirAttribute attr)127 int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) {
128 return unwrap(attr).cast<IntegerAttr>().getInt();
129 }
130
131 //===----------------------------------------------------------------------===//
132 // Bool attribute.
133 //===----------------------------------------------------------------------===//
134
mlirAttributeIsABool(MlirAttribute attr)135 bool mlirAttributeIsABool(MlirAttribute attr) {
136 return unwrap(attr).isa<BoolAttr>();
137 }
138
mlirBoolAttrGet(MlirContext ctx,int value)139 MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) {
140 return wrap(BoolAttr::get(value, unwrap(ctx)));
141 }
142
mlirBoolAttrGetValue(MlirAttribute attr)143 bool mlirBoolAttrGetValue(MlirAttribute attr) {
144 return unwrap(attr).cast<BoolAttr>().getValue();
145 }
146
147 //===----------------------------------------------------------------------===//
148 // Integer set attribute.
149 //===----------------------------------------------------------------------===//
150
mlirAttributeIsAIntegerSet(MlirAttribute attr)151 bool mlirAttributeIsAIntegerSet(MlirAttribute attr) {
152 return unwrap(attr).isa<IntegerSetAttr>();
153 }
154
155 //===----------------------------------------------------------------------===//
156 // Opaque attribute.
157 //===----------------------------------------------------------------------===//
158
mlirAttributeIsAOpaque(MlirAttribute attr)159 bool mlirAttributeIsAOpaque(MlirAttribute attr) {
160 return unwrap(attr).isa<OpaqueAttr>();
161 }
162
mlirOpaqueAttrGet(MlirContext ctx,MlirStringRef dialectNamespace,intptr_t dataLength,const char * data,MlirType type)163 MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace,
164 intptr_t dataLength, const char *data,
165 MlirType type) {
166 return wrap(
167 OpaqueAttr::get(Identifier::get(unwrap(dialectNamespace), unwrap(ctx)),
168 StringRef(data, dataLength), unwrap(type), unwrap(ctx)));
169 }
170
mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr)171 MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) {
172 return wrap(unwrap(attr).cast<OpaqueAttr>().getDialectNamespace().strref());
173 }
174
mlirOpaqueAttrGetData(MlirAttribute attr)175 MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) {
176 return wrap(unwrap(attr).cast<OpaqueAttr>().getAttrData());
177 }
178
179 //===----------------------------------------------------------------------===//
180 // String attribute.
181 //===----------------------------------------------------------------------===//
182
mlirAttributeIsAString(MlirAttribute attr)183 bool mlirAttributeIsAString(MlirAttribute attr) {
184 return unwrap(attr).isa<StringAttr>();
185 }
186
mlirStringAttrGet(MlirContext ctx,MlirStringRef str)187 MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) {
188 return wrap(StringAttr::get(unwrap(str), unwrap(ctx)));
189 }
190
mlirStringAttrTypedGet(MlirType type,MlirStringRef str)191 MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) {
192 return wrap(StringAttr::get(unwrap(str), unwrap(type)));
193 }
194
mlirStringAttrGetValue(MlirAttribute attr)195 MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) {
196 return wrap(unwrap(attr).cast<StringAttr>().getValue());
197 }
198
199 //===----------------------------------------------------------------------===//
200 // SymbolRef attribute.
201 //===----------------------------------------------------------------------===//
202
mlirAttributeIsASymbolRef(MlirAttribute attr)203 bool mlirAttributeIsASymbolRef(MlirAttribute attr) {
204 return unwrap(attr).isa<SymbolRefAttr>();
205 }
206
mlirSymbolRefAttrGet(MlirContext ctx,MlirStringRef symbol,intptr_t numReferences,MlirAttribute const * references)207 MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol,
208 intptr_t numReferences,
209 MlirAttribute const *references) {
210 SmallVector<FlatSymbolRefAttr, 4> refs;
211 refs.reserve(numReferences);
212 for (intptr_t i = 0; i < numReferences; ++i)
213 refs.push_back(unwrap(references[i]).cast<FlatSymbolRefAttr>());
214 return wrap(SymbolRefAttr::get(unwrap(symbol), refs, unwrap(ctx)));
215 }
216
mlirSymbolRefAttrGetRootReference(MlirAttribute attr)217 MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) {
218 return wrap(unwrap(attr).cast<SymbolRefAttr>().getRootReference());
219 }
220
mlirSymbolRefAttrGetLeafReference(MlirAttribute attr)221 MlirStringRef mlirSymbolRefAttrGetLeafReference(MlirAttribute attr) {
222 return wrap(unwrap(attr).cast<SymbolRefAttr>().getLeafReference());
223 }
224
mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr)225 intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) {
226 return static_cast<intptr_t>(
227 unwrap(attr).cast<SymbolRefAttr>().getNestedReferences().size());
228 }
229
mlirSymbolRefAttrGetNestedReference(MlirAttribute attr,intptr_t pos)230 MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr,
231 intptr_t pos) {
232 return wrap(unwrap(attr).cast<SymbolRefAttr>().getNestedReferences()[pos]);
233 }
234
235 //===----------------------------------------------------------------------===//
236 // Flat SymbolRef attribute.
237 //===----------------------------------------------------------------------===//
238
mlirAttributeIsAFlatSymbolRef(MlirAttribute attr)239 bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) {
240 return unwrap(attr).isa<FlatSymbolRefAttr>();
241 }
242
mlirFlatSymbolRefAttrGet(MlirContext ctx,MlirStringRef symbol)243 MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) {
244 return wrap(FlatSymbolRefAttr::get(unwrap(symbol), unwrap(ctx)));
245 }
246
mlirFlatSymbolRefAttrGetValue(MlirAttribute attr)247 MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) {
248 return wrap(unwrap(attr).cast<FlatSymbolRefAttr>().getValue());
249 }
250
251 //===----------------------------------------------------------------------===//
252 // Type attribute.
253 //===----------------------------------------------------------------------===//
254
mlirAttributeIsAType(MlirAttribute attr)255 bool mlirAttributeIsAType(MlirAttribute attr) {
256 return unwrap(attr).isa<TypeAttr>();
257 }
258
mlirTypeAttrGet(MlirType type)259 MlirAttribute mlirTypeAttrGet(MlirType type) {
260 return wrap(TypeAttr::get(unwrap(type)));
261 }
262
mlirTypeAttrGetValue(MlirAttribute attr)263 MlirType mlirTypeAttrGetValue(MlirAttribute attr) {
264 return wrap(unwrap(attr).cast<TypeAttr>().getValue());
265 }
266
267 //===----------------------------------------------------------------------===//
268 // Unit attribute.
269 //===----------------------------------------------------------------------===//
270
mlirAttributeIsAUnit(MlirAttribute attr)271 bool mlirAttributeIsAUnit(MlirAttribute attr) {
272 return unwrap(attr).isa<UnitAttr>();
273 }
274
mlirUnitAttrGet(MlirContext ctx)275 MlirAttribute mlirUnitAttrGet(MlirContext ctx) {
276 return wrap(UnitAttr::get(unwrap(ctx)));
277 }
278
279 //===----------------------------------------------------------------------===//
280 // Elements attributes.
281 //===----------------------------------------------------------------------===//
282
mlirAttributeIsAElements(MlirAttribute attr)283 bool mlirAttributeIsAElements(MlirAttribute attr) {
284 return unwrap(attr).isa<ElementsAttr>();
285 }
286
mlirElementsAttrGetValue(MlirAttribute attr,intptr_t rank,uint64_t * idxs)287 MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank,
288 uint64_t *idxs) {
289 return wrap(unwrap(attr).cast<ElementsAttr>().getValue(
290 llvm::makeArrayRef(idxs, rank)));
291 }
292
mlirElementsAttrIsValidIndex(MlirAttribute attr,intptr_t rank,uint64_t * idxs)293 bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank,
294 uint64_t *idxs) {
295 return unwrap(attr).cast<ElementsAttr>().isValidIndex(
296 llvm::makeArrayRef(idxs, rank));
297 }
298
mlirElementsAttrGetNumElements(MlirAttribute attr)299 int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) {
300 return unwrap(attr).cast<ElementsAttr>().getNumElements();
301 }
302
303 //===----------------------------------------------------------------------===//
304 // Dense elements attribute.
305 //===----------------------------------------------------------------------===//
306
307 //===----------------------------------------------------------------------===//
308 // IsA support.
309
mlirAttributeIsADenseElements(MlirAttribute attr)310 bool mlirAttributeIsADenseElements(MlirAttribute attr) {
311 return unwrap(attr).isa<DenseElementsAttr>();
312 }
mlirAttributeIsADenseIntElements(MlirAttribute attr)313 bool mlirAttributeIsADenseIntElements(MlirAttribute attr) {
314 return unwrap(attr).isa<DenseIntElementsAttr>();
315 }
mlirAttributeIsADenseFPElements(MlirAttribute attr)316 bool mlirAttributeIsADenseFPElements(MlirAttribute attr) {
317 return unwrap(attr).isa<DenseFPElementsAttr>();
318 }
319
320 //===----------------------------------------------------------------------===//
321 // Constructors.
322
mlirDenseElementsAttrGet(MlirType shapedType,intptr_t numElements,MlirAttribute const * elements)323 MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType,
324 intptr_t numElements,
325 MlirAttribute const *elements) {
326 SmallVector<Attribute, 8> attributes;
327 return wrap(
328 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
329 unwrapList(numElements, elements, attributes)));
330 }
331
mlirDenseElementsAttrSplatGet(MlirType shapedType,MlirAttribute element)332 MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType,
333 MlirAttribute element) {
334 return wrap(DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
335 unwrap(element)));
336 }
mlirDenseElementsAttrBoolSplatGet(MlirType shapedType,bool element)337 MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType,
338 bool element) {
339 return wrap(
340 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
341 }
mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType,uint32_t element)342 MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType,
343 uint32_t element) {
344 return wrap(
345 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
346 }
mlirDenseElementsAttrInt32SplatGet(MlirType shapedType,int32_t element)347 MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType,
348 int32_t element) {
349 return wrap(
350 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
351 }
mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType,uint64_t element)352 MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType,
353 uint64_t element) {
354 return wrap(
355 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
356 }
mlirDenseElementsAttrInt64SplatGet(MlirType shapedType,int64_t element)357 MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType,
358 int64_t element) {
359 return wrap(
360 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
361 }
mlirDenseElementsAttrFloatSplatGet(MlirType shapedType,float element)362 MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType,
363 float element) {
364 return wrap(
365 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
366 }
mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType,double element)367 MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType,
368 double element) {
369 return wrap(
370 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
371 }
372
mlirDenseElementsAttrBoolGet(MlirType shapedType,intptr_t numElements,const int * elements)373 MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType,
374 intptr_t numElements,
375 const int *elements) {
376 SmallVector<bool, 8> values(elements, elements + numElements);
377 return wrap(
378 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), values));
379 }
380
381 /// Creates a dense attribute with elements of the type deduced by templates.
382 template <typename T>
getDenseAttribute(MlirType shapedType,intptr_t numElements,const T * elements)383 static MlirAttribute getDenseAttribute(MlirType shapedType,
384 intptr_t numElements,
385 const T *elements) {
386 return wrap(
387 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
388 llvm::makeArrayRef(elements, numElements)));
389 }
390
mlirDenseElementsAttrUInt32Get(MlirType shapedType,intptr_t numElements,const uint32_t * elements)391 MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
392 intptr_t numElements,
393 const uint32_t *elements) {
394 return getDenseAttribute(shapedType, numElements, elements);
395 }
mlirDenseElementsAttrInt32Get(MlirType shapedType,intptr_t numElements,const int32_t * elements)396 MlirAttribute mlirDenseElementsAttrInt32Get(MlirType shapedType,
397 intptr_t numElements,
398 const int32_t *elements) {
399 return getDenseAttribute(shapedType, numElements, elements);
400 }
mlirDenseElementsAttrUInt64Get(MlirType shapedType,intptr_t numElements,const uint64_t * elements)401 MlirAttribute mlirDenseElementsAttrUInt64Get(MlirType shapedType,
402 intptr_t numElements,
403 const uint64_t *elements) {
404 return getDenseAttribute(shapedType, numElements, elements);
405 }
mlirDenseElementsAttrInt64Get(MlirType shapedType,intptr_t numElements,const int64_t * elements)406 MlirAttribute mlirDenseElementsAttrInt64Get(MlirType shapedType,
407 intptr_t numElements,
408 const int64_t *elements) {
409 return getDenseAttribute(shapedType, numElements, elements);
410 }
mlirDenseElementsAttrFloatGet(MlirType shapedType,intptr_t numElements,const float * elements)411 MlirAttribute mlirDenseElementsAttrFloatGet(MlirType shapedType,
412 intptr_t numElements,
413 const float *elements) {
414 return getDenseAttribute(shapedType, numElements, elements);
415 }
mlirDenseElementsAttrDoubleGet(MlirType shapedType,intptr_t numElements,const double * elements)416 MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType,
417 intptr_t numElements,
418 const double *elements) {
419 return getDenseAttribute(shapedType, numElements, elements);
420 }
421
mlirDenseElementsAttrStringGet(MlirType shapedType,intptr_t numElements,MlirStringRef * strs)422 MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
423 intptr_t numElements,
424 MlirStringRef *strs) {
425 SmallVector<StringRef, 8> values;
426 values.reserve(numElements);
427 for (intptr_t i = 0; i < numElements; ++i)
428 values.push_back(unwrap(strs[i]));
429
430 return wrap(
431 DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), values));
432 }
433
mlirDenseElementsAttrReshapeGet(MlirAttribute attr,MlirType shapedType)434 MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
435 MlirType shapedType) {
436 return wrap(unwrap(attr).cast<DenseElementsAttr>().reshape(
437 unwrap(shapedType).cast<ShapedType>()));
438 }
439
440 //===----------------------------------------------------------------------===//
441 // Splat accessors.
442
mlirDenseElementsAttrIsSplat(MlirAttribute attr)443 bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) {
444 return unwrap(attr).cast<DenseElementsAttr>().isSplat();
445 }
446
mlirDenseElementsAttrGetSplatValue(MlirAttribute attr)447 MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) {
448 return wrap(unwrap(attr).cast<DenseElementsAttr>().getSplatValue());
449 }
mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr)450 int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) {
451 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<bool>();
452 }
mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr)453 int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) {
454 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int32_t>();
455 }
mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr)456 uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr) {
457 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint32_t>();
458 }
mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr)459 int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr) {
460 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int64_t>();
461 }
mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr)462 uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr) {
463 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint64_t>();
464 }
mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr)465 float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) {
466 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<float>();
467 }
mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr)468 double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) {
469 return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<double>();
470 }
mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr)471 MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) {
472 return wrap(
473 unwrap(attr).cast<DenseElementsAttr>().getSplatValue<StringRef>());
474 }
475
476 //===----------------------------------------------------------------------===//
477 // Indexed accessors.
478
mlirDenseElementsAttrGetBoolValue(MlirAttribute attr,intptr_t pos)479 bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) {
480 return *(unwrap(attr).cast<DenseElementsAttr>().getValues<bool>().begin() +
481 pos);
482 }
mlirDenseElementsAttrGetInt32Value(MlirAttribute attr,intptr_t pos)483 int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
484 return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int32_t>().begin() +
485 pos);
486 }
mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr,intptr_t pos)487 uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) {
488 return *(
489 unwrap(attr).cast<DenseElementsAttr>().getValues<uint32_t>().begin() +
490 pos);
491 }
mlirDenseElementsAttrGetInt64Value(MlirAttribute attr,intptr_t pos)492 int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) {
493 return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int64_t>().begin() +
494 pos);
495 }
mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr,intptr_t pos)496 uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) {
497 return *(
498 unwrap(attr).cast<DenseElementsAttr>().getValues<uint64_t>().begin() +
499 pos);
500 }
mlirDenseElementsAttrGetFloatValue(MlirAttribute attr,intptr_t pos)501 float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) {
502 return *(unwrap(attr).cast<DenseElementsAttr>().getValues<float>().begin() +
503 pos);
504 }
mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr,intptr_t pos)505 double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) {
506 return *(unwrap(attr).cast<DenseElementsAttr>().getValues<double>().begin() +
507 pos);
508 }
mlirDenseElementsAttrGetStringValue(MlirAttribute attr,intptr_t pos)509 MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr,
510 intptr_t pos) {
511 return wrap(
512 *(unwrap(attr).cast<DenseElementsAttr>().getValues<StringRef>().begin() +
513 pos));
514 }
515
516 //===----------------------------------------------------------------------===//
517 // Raw data accessors.
518
mlirDenseElementsAttrGetRawData(MlirAttribute attr)519 const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) {
520 return static_cast<const void *>(
521 unwrap(attr).cast<DenseElementsAttr>().getRawData().data());
522 }
523
524 //===----------------------------------------------------------------------===//
525 // Opaque elements attribute.
526 //===----------------------------------------------------------------------===//
527
mlirAttributeIsAOpaqueElements(MlirAttribute attr)528 bool mlirAttributeIsAOpaqueElements(MlirAttribute attr) {
529 return unwrap(attr).isa<OpaqueElementsAttr>();
530 }
531
532 //===----------------------------------------------------------------------===//
533 // Sparse elements attribute.
534 //===----------------------------------------------------------------------===//
535
mlirAttributeIsASparseElements(MlirAttribute attr)536 bool mlirAttributeIsASparseElements(MlirAttribute attr) {
537 return unwrap(attr).isa<SparseElementsAttr>();
538 }
539
mlirSparseElementsAttribute(MlirType shapedType,MlirAttribute denseIndices,MlirAttribute denseValues)540 MlirAttribute mlirSparseElementsAttribute(MlirType shapedType,
541 MlirAttribute denseIndices,
542 MlirAttribute denseValues) {
543 return wrap(
544 SparseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
545 unwrap(denseIndices).cast<DenseElementsAttr>(),
546 unwrap(denseValues).cast<DenseElementsAttr>()));
547 }
548
mlirSparseElementsAttrGetIndices(MlirAttribute attr)549 MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) {
550 return wrap(unwrap(attr).cast<SparseElementsAttr>().getIndices());
551 }
552
mlirSparseElementsAttrGetValues(MlirAttribute attr)553 MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) {
554 return wrap(unwrap(attr).cast<SparseElementsAttr>().getValues());
555 }
556