//===- BuiltinAttributes.cpp - C Interface to MLIR Builtin Attributes -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir-c/BuiltinAttributes.h"
#include "mlir/CAPI/AffineMap.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Affine map attribute.
//===----------------------------------------------------------------------===//

mlirAttributeIsAAffineMap(MlirAttribute attr)22 bool mlirAttributeIsAAffineMap(MlirAttribute attr) {
23   return unwrap(attr).isa<AffineMapAttr>();
24 }
25 
mlirAffineMapAttrGet(MlirAffineMap map)26 MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) {
27   return wrap(AffineMapAttr::get(unwrap(map)));
28 }
29 
mlirAffineMapAttrGetValue(MlirAttribute attr)30 MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) {
31   return wrap(unwrap(attr).cast<AffineMapAttr>().getValue());
32 }
33 
//===----------------------------------------------------------------------===//
// Array attribute.
//===----------------------------------------------------------------------===//

mlirAttributeIsAArray(MlirAttribute attr)38 bool mlirAttributeIsAArray(MlirAttribute attr) {
39   return unwrap(attr).isa<ArrayAttr>();
40 }
41 
mlirArrayAttrGet(MlirContext ctx,intptr_t numElements,MlirAttribute const * elements)42 MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements,
43                                MlirAttribute const *elements) {
44   SmallVector<Attribute, 8> attrs;
45   return wrap(ArrayAttr::get(
46       unwrapList(static_cast<size_t>(numElements), elements, attrs),
47       unwrap(ctx)));
48 }
49 
mlirArrayAttrGetNumElements(MlirAttribute attr)50 intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) {
51   return static_cast<intptr_t>(unwrap(attr).cast<ArrayAttr>().size());
52 }
53 
mlirArrayAttrGetElement(MlirAttribute attr,intptr_t pos)54 MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) {
55   return wrap(unwrap(attr).cast<ArrayAttr>().getValue()[pos]);
56 }
57 
//===----------------------------------------------------------------------===//
// Dictionary attribute.
//===----------------------------------------------------------------------===//

mlirAttributeIsADictionary(MlirAttribute attr)62 bool mlirAttributeIsADictionary(MlirAttribute attr) {
63   return unwrap(attr).isa<DictionaryAttr>();
64 }
65 
mlirDictionaryAttrGet(MlirContext ctx,intptr_t numElements,MlirNamedAttribute const * elements)66 MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements,
67                                     MlirNamedAttribute const *elements) {
68   SmallVector<NamedAttribute, 8> attributes;
69   attributes.reserve(numElements);
70   for (intptr_t i = 0; i < numElements; ++i)
71     attributes.emplace_back(
72         Identifier::get(unwrap(elements[i].name), unwrap(ctx)),
73         unwrap(elements[i].attribute));
74   return wrap(DictionaryAttr::get(attributes, unwrap(ctx)));
75 }
76 
mlirDictionaryAttrGetNumElements(MlirAttribute attr)77 intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) {
78   return static_cast<intptr_t>(unwrap(attr).cast<DictionaryAttr>().size());
79 }
80 
mlirDictionaryAttrGetElement(MlirAttribute attr,intptr_t pos)81 MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr,
82                                                 intptr_t pos) {
83   NamedAttribute attribute =
84       unwrap(attr).cast<DictionaryAttr>().getValue()[pos];
85   return {wrap(attribute.first.strref()), wrap(attribute.second)};
86 }
87 
mlirDictionaryAttrGetElementByName(MlirAttribute attr,MlirStringRef name)88 MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr,
89                                                  MlirStringRef name) {
90   return wrap(unwrap(attr).cast<DictionaryAttr>().get(unwrap(name)));
91 }
92 
//===----------------------------------------------------------------------===//
// Floating point attribute.
//===----------------------------------------------------------------------===//

mlirAttributeIsAFloat(MlirAttribute attr)97 bool mlirAttributeIsAFloat(MlirAttribute attr) {
98   return unwrap(attr).isa<FloatAttr>();
99 }
100 
mlirFloatAttrDoubleGet(MlirContext ctx,MlirType type,double value)101 MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type,
102                                      double value) {
103   return wrap(FloatAttr::get(unwrap(type), value));
104 }
105 
mlirFloatAttrDoubleGetChecked(MlirType type,double value,MlirLocation loc)106 MlirAttribute mlirFloatAttrDoubleGetChecked(MlirType type, double value,
107                                             MlirLocation loc) {
108   return wrap(FloatAttr::getChecked(unwrap(type), value, unwrap(loc)));
109 }
110 
mlirFloatAttrGetValueDouble(MlirAttribute attr)111 double mlirFloatAttrGetValueDouble(MlirAttribute attr) {
112   return unwrap(attr).cast<FloatAttr>().getValueAsDouble();
113 }
114 
//===----------------------------------------------------------------------===//
// Integer attribute.
//===----------------------------------------------------------------------===//

mlirAttributeIsAInteger(MlirAttribute attr)119 bool mlirAttributeIsAInteger(MlirAttribute attr) {
120   return unwrap(attr).isa<IntegerAttr>();
121 }
122 
mlirIntegerAttrGet(MlirType type,int64_t value)123 MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) {
124   return wrap(IntegerAttr::get(unwrap(type), value));
125 }
126 
mlirIntegerAttrGetValueInt(MlirAttribute attr)127 int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) {
128   return unwrap(attr).cast<IntegerAttr>().getInt();
129 }
130 
//===----------------------------------------------------------------------===//
// Bool attribute.
//===----------------------------------------------------------------------===//

mlirAttributeIsABool(MlirAttribute attr)135 bool mlirAttributeIsABool(MlirAttribute attr) {
136   return unwrap(attr).isa<BoolAttr>();
137 }
138 
mlirBoolAttrGet(MlirContext ctx,int value)139 MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) {
140   return wrap(BoolAttr::get(value, unwrap(ctx)));
141 }
142 
mlirBoolAttrGetValue(MlirAttribute attr)143 bool mlirBoolAttrGetValue(MlirAttribute attr) {
144   return unwrap(attr).cast<BoolAttr>().getValue();
145 }
146 
//===----------------------------------------------------------------------===//
// Integer set attribute.
//===----------------------------------------------------------------------===//

mlirAttributeIsAIntegerSet(MlirAttribute attr)151 bool mlirAttributeIsAIntegerSet(MlirAttribute attr) {
152   return unwrap(attr).isa<IntegerSetAttr>();
153 }
154 
//===----------------------------------------------------------------------===//
// Opaque attribute.
//===----------------------------------------------------------------------===//

mlirAttributeIsAOpaque(MlirAttribute attr)159 bool mlirAttributeIsAOpaque(MlirAttribute attr) {
160   return unwrap(attr).isa<OpaqueAttr>();
161 }
162 
mlirOpaqueAttrGet(MlirContext ctx,MlirStringRef dialectNamespace,intptr_t dataLength,const char * data,MlirType type)163 MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace,
164                                 intptr_t dataLength, const char *data,
165                                 MlirType type) {
166   return wrap(
167       OpaqueAttr::get(Identifier::get(unwrap(dialectNamespace), unwrap(ctx)),
168                       StringRef(data, dataLength), unwrap(type), unwrap(ctx)));
169 }
170 
mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr)171 MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) {
172   return wrap(unwrap(attr).cast<OpaqueAttr>().getDialectNamespace().strref());
173 }
174 
mlirOpaqueAttrGetData(MlirAttribute attr)175 MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) {
176   return wrap(unwrap(attr).cast<OpaqueAttr>().getAttrData());
177 }
178 
//===----------------------------------------------------------------------===//
// String attribute.
//===----------------------------------------------------------------------===//

mlirAttributeIsAString(MlirAttribute attr)183 bool mlirAttributeIsAString(MlirAttribute attr) {
184   return unwrap(attr).isa<StringAttr>();
185 }
186 
mlirStringAttrGet(MlirContext ctx,MlirStringRef str)187 MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) {
188   return wrap(StringAttr::get(unwrap(str), unwrap(ctx)));
189 }
190 
mlirStringAttrTypedGet(MlirType type,MlirStringRef str)191 MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) {
192   return wrap(StringAttr::get(unwrap(str), unwrap(type)));
193 }
194 
mlirStringAttrGetValue(MlirAttribute attr)195 MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) {
196   return wrap(unwrap(attr).cast<StringAttr>().getValue());
197 }
198 
//===----------------------------------------------------------------------===//
// SymbolRef attribute.
//===----------------------------------------------------------------------===//

mlirAttributeIsASymbolRef(MlirAttribute attr)203 bool mlirAttributeIsASymbolRef(MlirAttribute attr) {
204   return unwrap(attr).isa<SymbolRefAttr>();
205 }
206 
mlirSymbolRefAttrGet(MlirContext ctx,MlirStringRef symbol,intptr_t numReferences,MlirAttribute const * references)207 MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol,
208                                    intptr_t numReferences,
209                                    MlirAttribute const *references) {
210   SmallVector<FlatSymbolRefAttr, 4> refs;
211   refs.reserve(numReferences);
212   for (intptr_t i = 0; i < numReferences; ++i)
213     refs.push_back(unwrap(references[i]).cast<FlatSymbolRefAttr>());
214   return wrap(SymbolRefAttr::get(unwrap(symbol), refs, unwrap(ctx)));
215 }
216 
mlirSymbolRefAttrGetRootReference(MlirAttribute attr)217 MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) {
218   return wrap(unwrap(attr).cast<SymbolRefAttr>().getRootReference());
219 }
220 
mlirSymbolRefAttrGetLeafReference(MlirAttribute attr)221 MlirStringRef mlirSymbolRefAttrGetLeafReference(MlirAttribute attr) {
222   return wrap(unwrap(attr).cast<SymbolRefAttr>().getLeafReference());
223 }
224 
mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr)225 intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) {
226   return static_cast<intptr_t>(
227       unwrap(attr).cast<SymbolRefAttr>().getNestedReferences().size());
228 }
229 
mlirSymbolRefAttrGetNestedReference(MlirAttribute attr,intptr_t pos)230 MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr,
231                                                   intptr_t pos) {
232   return wrap(unwrap(attr).cast<SymbolRefAttr>().getNestedReferences()[pos]);
233 }
234 
//===----------------------------------------------------------------------===//
// Flat SymbolRef attribute.
//===----------------------------------------------------------------------===//

mlirAttributeIsAFlatSymbolRef(MlirAttribute attr)239 bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) {
240   return unwrap(attr).isa<FlatSymbolRefAttr>();
241 }
242 
mlirFlatSymbolRefAttrGet(MlirContext ctx,MlirStringRef symbol)243 MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) {
244   return wrap(FlatSymbolRefAttr::get(unwrap(symbol), unwrap(ctx)));
245 }
246 
mlirFlatSymbolRefAttrGetValue(MlirAttribute attr)247 MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) {
248   return wrap(unwrap(attr).cast<FlatSymbolRefAttr>().getValue());
249 }
250 
//===----------------------------------------------------------------------===//
// Type attribute.
//===----------------------------------------------------------------------===//

mlirAttributeIsAType(MlirAttribute attr)255 bool mlirAttributeIsAType(MlirAttribute attr) {
256   return unwrap(attr).isa<TypeAttr>();
257 }
258 
mlirTypeAttrGet(MlirType type)259 MlirAttribute mlirTypeAttrGet(MlirType type) {
260   return wrap(TypeAttr::get(unwrap(type)));
261 }
262 
mlirTypeAttrGetValue(MlirAttribute attr)263 MlirType mlirTypeAttrGetValue(MlirAttribute attr) {
264   return wrap(unwrap(attr).cast<TypeAttr>().getValue());
265 }
266 
//===----------------------------------------------------------------------===//
// Unit attribute.
//===----------------------------------------------------------------------===//

mlirAttributeIsAUnit(MlirAttribute attr)271 bool mlirAttributeIsAUnit(MlirAttribute attr) {
272   return unwrap(attr).isa<UnitAttr>();
273 }
274 
mlirUnitAttrGet(MlirContext ctx)275 MlirAttribute mlirUnitAttrGet(MlirContext ctx) {
276   return wrap(UnitAttr::get(unwrap(ctx)));
277 }
278 
//===----------------------------------------------------------------------===//
// Elements attributes.
//===----------------------------------------------------------------------===//

mlirAttributeIsAElements(MlirAttribute attr)283 bool mlirAttributeIsAElements(MlirAttribute attr) {
284   return unwrap(attr).isa<ElementsAttr>();
285 }
286 
mlirElementsAttrGetValue(MlirAttribute attr,intptr_t rank,uint64_t * idxs)287 MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank,
288                                        uint64_t *idxs) {
289   return wrap(unwrap(attr).cast<ElementsAttr>().getValue(
290       llvm::makeArrayRef(idxs, rank)));
291 }
292 
mlirElementsAttrIsValidIndex(MlirAttribute attr,intptr_t rank,uint64_t * idxs)293 bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank,
294                                   uint64_t *idxs) {
295   return unwrap(attr).cast<ElementsAttr>().isValidIndex(
296       llvm::makeArrayRef(idxs, rank));
297 }
298 
mlirElementsAttrGetNumElements(MlirAttribute attr)299 int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) {
300   return unwrap(attr).cast<ElementsAttr>().getNumElements();
301 }
302 
//===----------------------------------------------------------------------===//
// Dense elements attribute.
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// IsA support.

mlirAttributeIsADenseElements(MlirAttribute attr)310 bool mlirAttributeIsADenseElements(MlirAttribute attr) {
311   return unwrap(attr).isa<DenseElementsAttr>();
312 }
mlirAttributeIsADenseIntElements(MlirAttribute attr)313 bool mlirAttributeIsADenseIntElements(MlirAttribute attr) {
314   return unwrap(attr).isa<DenseIntElementsAttr>();
315 }
mlirAttributeIsADenseFPElements(MlirAttribute attr)316 bool mlirAttributeIsADenseFPElements(MlirAttribute attr) {
317   return unwrap(attr).isa<DenseFPElementsAttr>();
318 }
319 
//===----------------------------------------------------------------------===//
// Constructors.

mlirDenseElementsAttrGet(MlirType shapedType,intptr_t numElements,MlirAttribute const * elements)323 MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType,
324                                        intptr_t numElements,
325                                        MlirAttribute const *elements) {
326   SmallVector<Attribute, 8> attributes;
327   return wrap(
328       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
329                              unwrapList(numElements, elements, attributes)));
330 }
331 
mlirDenseElementsAttrSplatGet(MlirType shapedType,MlirAttribute element)332 MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType,
333                                             MlirAttribute element) {
334   return wrap(DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
335                                      unwrap(element)));
336 }
mlirDenseElementsAttrBoolSplatGet(MlirType shapedType,bool element)337 MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType,
338                                                 bool element) {
339   return wrap(
340       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
341 }
mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType,uint32_t element)342 MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType,
343                                                   uint32_t element) {
344   return wrap(
345       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
346 }
mlirDenseElementsAttrInt32SplatGet(MlirType shapedType,int32_t element)347 MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType,
348                                                  int32_t element) {
349   return wrap(
350       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
351 }
mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType,uint64_t element)352 MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType,
353                                                   uint64_t element) {
354   return wrap(
355       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
356 }
mlirDenseElementsAttrInt64SplatGet(MlirType shapedType,int64_t element)357 MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType,
358                                                  int64_t element) {
359   return wrap(
360       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
361 }
mlirDenseElementsAttrFloatSplatGet(MlirType shapedType,float element)362 MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType,
363                                                  float element) {
364   return wrap(
365       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
366 }
mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType,double element)367 MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType,
368                                                   double element) {
369   return wrap(
370       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
371 }
372 
mlirDenseElementsAttrBoolGet(MlirType shapedType,intptr_t numElements,const int * elements)373 MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType,
374                                            intptr_t numElements,
375                                            const int *elements) {
376   SmallVector<bool, 8> values(elements, elements + numElements);
377   return wrap(
378       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), values));
379 }
380 
381 /// Creates a dense attribute with elements of the type deduced by templates.
382 template <typename T>
getDenseAttribute(MlirType shapedType,intptr_t numElements,const T * elements)383 static MlirAttribute getDenseAttribute(MlirType shapedType,
384                                        intptr_t numElements,
385                                        const T *elements) {
386   return wrap(
387       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
388                              llvm::makeArrayRef(elements, numElements)));
389 }
390 
mlirDenseElementsAttrUInt32Get(MlirType shapedType,intptr_t numElements,const uint32_t * elements)391 MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
392                                              intptr_t numElements,
393                                              const uint32_t *elements) {
394   return getDenseAttribute(shapedType, numElements, elements);
395 }
mlirDenseElementsAttrInt32Get(MlirType shapedType,intptr_t numElements,const int32_t * elements)396 MlirAttribute mlirDenseElementsAttrInt32Get(MlirType shapedType,
397                                             intptr_t numElements,
398                                             const int32_t *elements) {
399   return getDenseAttribute(shapedType, numElements, elements);
400 }
mlirDenseElementsAttrUInt64Get(MlirType shapedType,intptr_t numElements,const uint64_t * elements)401 MlirAttribute mlirDenseElementsAttrUInt64Get(MlirType shapedType,
402                                              intptr_t numElements,
403                                              const uint64_t *elements) {
404   return getDenseAttribute(shapedType, numElements, elements);
405 }
mlirDenseElementsAttrInt64Get(MlirType shapedType,intptr_t numElements,const int64_t * elements)406 MlirAttribute mlirDenseElementsAttrInt64Get(MlirType shapedType,
407                                             intptr_t numElements,
408                                             const int64_t *elements) {
409   return getDenseAttribute(shapedType, numElements, elements);
410 }
mlirDenseElementsAttrFloatGet(MlirType shapedType,intptr_t numElements,const float * elements)411 MlirAttribute mlirDenseElementsAttrFloatGet(MlirType shapedType,
412                                             intptr_t numElements,
413                                             const float *elements) {
414   return getDenseAttribute(shapedType, numElements, elements);
415 }
mlirDenseElementsAttrDoubleGet(MlirType shapedType,intptr_t numElements,const double * elements)416 MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType,
417                                              intptr_t numElements,
418                                              const double *elements) {
419   return getDenseAttribute(shapedType, numElements, elements);
420 }
421 
mlirDenseElementsAttrStringGet(MlirType shapedType,intptr_t numElements,MlirStringRef * strs)422 MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
423                                              intptr_t numElements,
424                                              MlirStringRef *strs) {
425   SmallVector<StringRef, 8> values;
426   values.reserve(numElements);
427   for (intptr_t i = 0; i < numElements; ++i)
428     values.push_back(unwrap(strs[i]));
429 
430   return wrap(
431       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), values));
432 }
433 
mlirDenseElementsAttrReshapeGet(MlirAttribute attr,MlirType shapedType)434 MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
435                                               MlirType shapedType) {
436   return wrap(unwrap(attr).cast<DenseElementsAttr>().reshape(
437       unwrap(shapedType).cast<ShapedType>()));
438 }
439 
//===----------------------------------------------------------------------===//
// Splat accessors.

mlirDenseElementsAttrIsSplat(MlirAttribute attr)443 bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) {
444   return unwrap(attr).cast<DenseElementsAttr>().isSplat();
445 }
446 
mlirDenseElementsAttrGetSplatValue(MlirAttribute attr)447 MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) {
448   return wrap(unwrap(attr).cast<DenseElementsAttr>().getSplatValue());
449 }
mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr)450 int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) {
451   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<bool>();
452 }
mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr)453 int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) {
454   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int32_t>();
455 }
mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr)456 uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr) {
457   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint32_t>();
458 }
mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr)459 int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr) {
460   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int64_t>();
461 }
mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr)462 uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr) {
463   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint64_t>();
464 }
mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr)465 float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) {
466   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<float>();
467 }
mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr)468 double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) {
469   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<double>();
470 }
mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr)471 MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) {
472   return wrap(
473       unwrap(attr).cast<DenseElementsAttr>().getSplatValue<StringRef>());
474 }
475 
//===----------------------------------------------------------------------===//
// Indexed accessors.

mlirDenseElementsAttrGetBoolValue(MlirAttribute attr,intptr_t pos)479 bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) {
480   return *(unwrap(attr).cast<DenseElementsAttr>().getValues<bool>().begin() +
481            pos);
482 }
mlirDenseElementsAttrGetInt32Value(MlirAttribute attr,intptr_t pos)483 int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
484   return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int32_t>().begin() +
485            pos);
486 }
mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr,intptr_t pos)487 uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) {
488   return *(
489       unwrap(attr).cast<DenseElementsAttr>().getValues<uint32_t>().begin() +
490       pos);
491 }
mlirDenseElementsAttrGetInt64Value(MlirAttribute attr,intptr_t pos)492 int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) {
493   return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int64_t>().begin() +
494            pos);
495 }
mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr,intptr_t pos)496 uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) {
497   return *(
498       unwrap(attr).cast<DenseElementsAttr>().getValues<uint64_t>().begin() +
499       pos);
500 }
mlirDenseElementsAttrGetFloatValue(MlirAttribute attr,intptr_t pos)501 float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) {
502   return *(unwrap(attr).cast<DenseElementsAttr>().getValues<float>().begin() +
503            pos);
504 }
mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr,intptr_t pos)505 double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) {
506   return *(unwrap(attr).cast<DenseElementsAttr>().getValues<double>().begin() +
507            pos);
508 }
mlirDenseElementsAttrGetStringValue(MlirAttribute attr,intptr_t pos)509 MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr,
510                                                   intptr_t pos) {
511   return wrap(
512       *(unwrap(attr).cast<DenseElementsAttr>().getValues<StringRef>().begin() +
513         pos));
514 }
515 
//===----------------------------------------------------------------------===//
// Raw data accessors.

mlirDenseElementsAttrGetRawData(MlirAttribute attr)519 const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) {
520   return static_cast<const void *>(
521       unwrap(attr).cast<DenseElementsAttr>().getRawData().data());
522 }
523 
//===----------------------------------------------------------------------===//
// Opaque elements attribute.
//===----------------------------------------------------------------------===//

mlirAttributeIsAOpaqueElements(MlirAttribute attr)528 bool mlirAttributeIsAOpaqueElements(MlirAttribute attr) {
529   return unwrap(attr).isa<OpaqueElementsAttr>();
530 }
531 
//===----------------------------------------------------------------------===//
// Sparse elements attribute.
//===----------------------------------------------------------------------===//

mlirAttributeIsASparseElements(MlirAttribute attr)536 bool mlirAttributeIsASparseElements(MlirAttribute attr) {
537   return unwrap(attr).isa<SparseElementsAttr>();
538 }
539 
mlirSparseElementsAttribute(MlirType shapedType,MlirAttribute denseIndices,MlirAttribute denseValues)540 MlirAttribute mlirSparseElementsAttribute(MlirType shapedType,
541                                           MlirAttribute denseIndices,
542                                           MlirAttribute denseValues) {
543   return wrap(
544       SparseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
545                               unwrap(denseIndices).cast<DenseElementsAttr>(),
546                               unwrap(denseValues).cast<DenseElementsAttr>()));
547 }
548 
mlirSparseElementsAttrGetIndices(MlirAttribute attr)549 MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) {
550   return wrap(unwrap(attr).cast<SparseElementsAttr>().getIndices());
551 }
552 
mlirSparseElementsAttrGetValues(MlirAttribute attr)553 MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) {
554   return wrap(unwrap(attr).cast<SparseElementsAttr>().getValues());
555 }
556