1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/IR/BuiltinAttributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/Diagnostics.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/IntegerSet.h"
15 #include "mlir/IR/Types.h"
16 #include "mlir/Interfaces/DecodeAttributesInterfaces.h"
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Endian.h"
20
21 using namespace mlir;
22 using namespace mlir::detail;
23
24 //===----------------------------------------------------------------------===//
25 // AffineMapAttr
26 //===----------------------------------------------------------------------===//
27
get(AffineMap value)28 AffineMapAttr AffineMapAttr::get(AffineMap value) {
29 return Base::get(value.getContext(), value);
30 }
31
getValue() const32 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
33
34 //===----------------------------------------------------------------------===//
35 // ArrayAttr
36 //===----------------------------------------------------------------------===//
37
get(ArrayRef<Attribute> value,MLIRContext * context)38 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
39 return Base::get(context, value);
40 }
41
getValue() const42 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
43
operator [](unsigned idx) const44 Attribute ArrayAttr::operator[](unsigned idx) const {
45 assert(idx < size() && "index out of bounds");
46 return getValue()[idx];
47 }
48
49 //===----------------------------------------------------------------------===//
50 // DictionaryAttr
51 //===----------------------------------------------------------------------===//
52
53 /// Helper function that does either an in place sort or sorts from source array
54 /// into destination. If inPlace then storage is both the source and the
55 /// destination, else value is the source and storage destination. Returns
56 /// whether source was sorted.
57 template <bool inPlace>
dictionaryAttrSort(ArrayRef<NamedAttribute> value,SmallVectorImpl<NamedAttribute> & storage)58 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
59 SmallVectorImpl<NamedAttribute> &storage) {
60 // Specialize for the common case.
61 switch (value.size()) {
62 case 0:
63 // Zero already sorted.
64 break;
65 case 1:
66 // One already sorted but may need to be copied.
67 if (!inPlace)
68 storage.assign({value[0]});
69 break;
70 case 2: {
71 bool isSorted = value[0] < value[1];
72 if (inPlace) {
73 if (!isSorted)
74 std::swap(storage[0], storage[1]);
75 } else if (isSorted) {
76 storage.assign({value[0], value[1]});
77 } else {
78 storage.assign({value[1], value[0]});
79 }
80 return !isSorted;
81 }
82 default:
83 if (!inPlace)
84 storage.assign(value.begin(), value.end());
85 // Check to see they are sorted already.
86 bool isSorted = llvm::is_sorted(value);
87 if (!isSorted) {
88 // If not, do a general sort.
89 llvm::array_pod_sort(storage.begin(), storage.end());
90 value = storage;
91 }
92 return !isSorted;
93 }
94 return false;
95 }
96
97 /// Returns an entry with a duplicate name from the given sorted array of named
98 /// attributes. Returns llvm::None if all elements have unique names.
99 static Optional<NamedAttribute>
findDuplicateElement(ArrayRef<NamedAttribute> value)100 findDuplicateElement(ArrayRef<NamedAttribute> value) {
101 const Optional<NamedAttribute> none{llvm::None};
102 if (value.size() < 2)
103 return none;
104
105 if (value.size() == 2)
106 return value[0].first == value[1].first ? value[0] : none;
107
108 auto it = std::adjacent_find(
109 value.begin(), value.end(),
110 [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; });
111 return it != value.end() ? *it : none;
112 }
113
sort(ArrayRef<NamedAttribute> value,SmallVectorImpl<NamedAttribute> & storage)114 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
115 SmallVectorImpl<NamedAttribute> &storage) {
116 bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage);
117 assert(!findDuplicateElement(storage) &&
118 "DictionaryAttr element names must be unique");
119 return isSorted;
120 }
121
sortInPlace(SmallVectorImpl<NamedAttribute> & array)122 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
123 bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array);
124 assert(!findDuplicateElement(array) &&
125 "DictionaryAttr element names must be unique");
126 return isSorted;
127 }
128
129 Optional<NamedAttribute>
findDuplicate(SmallVectorImpl<NamedAttribute> & array,bool isSorted)130 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
131 bool isSorted) {
132 if (!isSorted)
133 dictionaryAttrSort</*inPlace=*/true>(array, array);
134 return findDuplicateElement(array);
135 }
136
get(ArrayRef<NamedAttribute> value,MLIRContext * context)137 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
138 MLIRContext *context) {
139 if (value.empty())
140 return DictionaryAttr::getEmpty(context);
141 assert(llvm::all_of(value,
142 [](const NamedAttribute &attr) { return attr.second; }) &&
143 "value cannot have null entries");
144
145 // We need to sort the element list to canonicalize it.
146 SmallVector<NamedAttribute, 8> storage;
147 if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
148 value = storage;
149 assert(!findDuplicateElement(value) &&
150 "DictionaryAttr element names must be unique");
151 return Base::get(context, value);
152 }
153 /// Construct a dictionary with an array of values that is known to already be
154 /// sorted by name and uniqued.
getWithSorted(ArrayRef<NamedAttribute> value,MLIRContext * context)155 DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value,
156 MLIRContext *context) {
157 if (value.empty())
158 return DictionaryAttr::getEmpty(context);
159 // Ensure that the attribute elements are unique and sorted.
160 assert(llvm::is_sorted(value,
161 [](NamedAttribute l, NamedAttribute r) {
162 return l.first.strref() < r.first.strref();
163 }) &&
164 "expected attribute values to be sorted");
165 assert(!findDuplicateElement(value) &&
166 "DictionaryAttr element names must be unique");
167 return Base::get(context, value);
168 }
169
getValue() const170 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
171 return getImpl()->getElements();
172 }
173
174 /// Return the specified attribute if present, null otherwise.
get(StringRef name) const175 Attribute DictionaryAttr::get(StringRef name) const {
176 Optional<NamedAttribute> attr = getNamed(name);
177 return attr ? attr->second : nullptr;
178 }
get(Identifier name) const179 Attribute DictionaryAttr::get(Identifier name) const {
180 Optional<NamedAttribute> attr = getNamed(name);
181 return attr ? attr->second : nullptr;
182 }
183
184 /// Return the specified named attribute if present, None otherwise.
getNamed(StringRef name) const185 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
186 ArrayRef<NamedAttribute> values = getValue();
187 const auto *it = llvm::lower_bound(values, name);
188 return it != values.end() && it->first == name ? *it
189 : Optional<NamedAttribute>();
190 }
getNamed(Identifier name) const191 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const {
192 for (auto elt : getValue())
193 if (elt.first == name)
194 return elt;
195 return llvm::None;
196 }
197
begin() const198 DictionaryAttr::iterator DictionaryAttr::begin() const {
199 return getValue().begin();
200 }
end() const201 DictionaryAttr::iterator DictionaryAttr::end() const {
202 return getValue().end();
203 }
size() const204 size_t DictionaryAttr::size() const { return getValue().size(); }
205
206 //===----------------------------------------------------------------------===//
207 // FloatAttr
208 //===----------------------------------------------------------------------===//
209
get(Type type,double value)210 FloatAttr FloatAttr::get(Type type, double value) {
211 return Base::get(type.getContext(), type, value);
212 }
213
getChecked(Type type,double value,Location loc)214 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
215 return Base::getChecked(loc, type, value);
216 }
217
get(Type type,const APFloat & value)218 FloatAttr FloatAttr::get(Type type, const APFloat &value) {
219 return Base::get(type.getContext(), type, value);
220 }
221
getChecked(Type type,const APFloat & value,Location loc)222 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
223 return Base::getChecked(loc, type, value);
224 }
225
getValue() const226 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
227
getValueAsDouble() const228 double FloatAttr::getValueAsDouble() const {
229 return getValueAsDouble(getValue());
230 }
getValueAsDouble(APFloat value)231 double FloatAttr::getValueAsDouble(APFloat value) {
232 if (&value.getSemantics() != &APFloat::IEEEdouble()) {
233 bool losesInfo = false;
234 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
235 &losesInfo);
236 }
237 return value.convertToDouble();
238 }
239
240 /// Verify construction invariants.
verifyFloatTypeInvariants(Location loc,Type type)241 static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) {
242 if (!type.isa<FloatType>())
243 return emitError(loc, "expected floating point type");
244 return success();
245 }
246
verifyConstructionInvariants(Location loc,Type type,double value)247 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
248 double value) {
249 return verifyFloatTypeInvariants(loc, type);
250 }
251
verifyConstructionInvariants(Location loc,Type type,const APFloat & value)252 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
253 const APFloat &value) {
254 // Verify that the type is correct.
255 if (failed(verifyFloatTypeInvariants(loc, type)))
256 return failure();
257
258 // Verify that the type semantics match that of the value.
259 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
260 return emitError(
261 loc, "FloatAttr type doesn't match the type implied by its value");
262 }
263 return success();
264 }
265
266 //===----------------------------------------------------------------------===//
267 // SymbolRefAttr
268 //===----------------------------------------------------------------------===//
269
get(StringRef value,MLIRContext * ctx)270 FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
271 return Base::get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>();
272 }
273
get(StringRef value,ArrayRef<FlatSymbolRefAttr> nestedReferences,MLIRContext * ctx)274 SymbolRefAttr SymbolRefAttr::get(StringRef value,
275 ArrayRef<FlatSymbolRefAttr> nestedReferences,
276 MLIRContext *ctx) {
277 return Base::get(ctx, value, nestedReferences);
278 }
279
getRootReference() const280 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; }
281
getLeafReference() const282 StringRef SymbolRefAttr::getLeafReference() const {
283 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
284 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue();
285 }
286
getNestedReferences() const287 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
288 return getImpl()->getNestedRefs();
289 }
290
291 //===----------------------------------------------------------------------===//
292 // IntegerAttr
293 //===----------------------------------------------------------------------===//
294
get(Type type,const APInt & value)295 IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
296 if (type.isSignlessInteger(1))
297 return BoolAttr::get(value.getBoolValue(), type.getContext());
298 return Base::get(type.getContext(), type, value);
299 }
300
get(Type type,int64_t value)301 IntegerAttr IntegerAttr::get(Type type, int64_t value) {
302 // This uses 64 bit APInts by default for index type.
303 if (type.isIndex())
304 return get(type, APInt(IndexType::kInternalStorageBitWidth, value));
305
306 auto intType = type.cast<IntegerType>();
307 return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger()));
308 }
309
getValue() const310 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }
311
getInt() const312 int64_t IntegerAttr::getInt() const {
313 assert((getImpl()->getType().isIndex() ||
314 getImpl()->getType().isSignlessInteger()) &&
315 "must be signless integer");
316 return getValue().getSExtValue();
317 }
318
getSInt() const319 int64_t IntegerAttr::getSInt() const {
320 assert(getImpl()->getType().isSignedInteger() && "must be signed integer");
321 return getValue().getSExtValue();
322 }
323
getUInt() const324 uint64_t IntegerAttr::getUInt() const {
325 assert(getImpl()->getType().isUnsignedInteger() &&
326 "must be unsigned integer");
327 return getValue().getZExtValue();
328 }
329
verifyIntegerTypeInvariants(Location loc,Type type)330 static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) {
331 if (type.isa<IntegerType, IndexType>())
332 return success();
333 return emitError(loc, "expected integer or index type");
334 }
335
verifyConstructionInvariants(Location loc,Type type,int64_t value)336 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
337 int64_t value) {
338 return verifyIntegerTypeInvariants(loc, type);
339 }
340
verifyConstructionInvariants(Location loc,Type type,const APInt & value)341 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
342 const APInt &value) {
343 if (failed(verifyIntegerTypeInvariants(loc, type)))
344 return failure();
345 if (auto integerType = type.dyn_cast<IntegerType>())
346 if (integerType.getWidth() != value.getBitWidth())
347 return emitError(loc, "integer type bit width (")
348 << integerType.getWidth() << ") doesn't match value bit width ("
349 << value.getBitWidth() << ")";
350 return success();
351 }
352
353 //===----------------------------------------------------------------------===//
354 // BoolAttr
355
getValue() const356 bool BoolAttr::getValue() const {
357 auto *storage = reinterpret_cast<IntegerAttributeStorage *>(impl);
358 return storage->getValue().getBoolValue();
359 }
360
classof(Attribute attr)361 bool BoolAttr::classof(Attribute attr) {
362 IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
363 return intAttr && intAttr.getType().isSignlessInteger(1);
364 }
365
366 //===----------------------------------------------------------------------===//
367 // IntegerSetAttr
368 //===----------------------------------------------------------------------===//
369
get(IntegerSet value)370 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
371 return Base::get(value.getConstraint(0).getContext(), value);
372 }
373
getValue() const374 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
375
376 //===----------------------------------------------------------------------===//
377 // OpaqueAttr
378 //===----------------------------------------------------------------------===//
379
get(Identifier dialect,StringRef attrData,Type type,MLIRContext * context)380 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
381 MLIRContext *context) {
382 return Base::get(context, dialect, attrData, type);
383 }
384
getChecked(Identifier dialect,StringRef attrData,Type type,Location location)385 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
386 Type type, Location location) {
387 return Base::getChecked(location, dialect, attrData, type);
388 }
389
390 /// Returns the dialect namespace of the opaque attribute.
getDialectNamespace() const391 Identifier OpaqueAttr::getDialectNamespace() const {
392 return getImpl()->dialectNamespace;
393 }
394
395 /// Returns the raw attribute data of the opaque attribute.
getAttrData() const396 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
397
398 /// Verify the construction of an opaque attribute.
verifyConstructionInvariants(Location loc,Identifier dialect,StringRef attrData,Type type)399 LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc,
400 Identifier dialect,
401 StringRef attrData,
402 Type type) {
403 if (!Dialect::isValidNamespace(dialect.strref()))
404 return emitError(loc, "invalid dialect namespace '") << dialect << "'";
405 return success();
406 }
407
408 //===----------------------------------------------------------------------===//
409 // StringAttr
410 //===----------------------------------------------------------------------===//
411
get(StringRef bytes,MLIRContext * context)412 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
413 return get(bytes, NoneType::get(context));
414 }
415
416 /// Get an instance of a StringAttr with the given string and Type.
get(StringRef bytes,Type type)417 StringAttr StringAttr::get(StringRef bytes, Type type) {
418 return Base::get(type.getContext(), bytes, type);
419 }
420
getValue() const421 StringRef StringAttr::getValue() const { return getImpl()->value; }
422
423 //===----------------------------------------------------------------------===//
424 // TypeAttr
425 //===----------------------------------------------------------------------===//
426
get(Type value)427 TypeAttr TypeAttr::get(Type value) {
428 return Base::get(value.getContext(), value);
429 }
430
getValue() const431 Type TypeAttr::getValue() const { return getImpl()->value; }
432
433 //===----------------------------------------------------------------------===//
434 // ElementsAttr
435 //===----------------------------------------------------------------------===//
436
getType() const437 ShapedType ElementsAttr::getType() const {
438 return Attribute::getType().cast<ShapedType>();
439 }
440
441 /// Returns the number of elements held by this attribute.
getNumElements() const442 int64_t ElementsAttr::getNumElements() const {
443 return getType().getNumElements();
444 }
445
446 /// Return the value at the given index. If index does not refer to a valid
447 /// element, then a null attribute is returned.
getValue(ArrayRef<uint64_t> index) const448 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
449 if (auto denseAttr = dyn_cast<DenseElementsAttr>())
450 return denseAttr.getValue(index);
451 if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>())
452 return opaqueAttr.getValue(index);
453 return cast<SparseElementsAttr>().getValue(index);
454 }
455
456 /// Return if the given 'index' refers to a valid element in this attribute.
isValidIndex(ArrayRef<uint64_t> index) const457 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
458 auto type = getType();
459
460 // Verify that the rank of the indices matches the held type.
461 auto rank = type.getRank();
462 if (rank != static_cast<int64_t>(index.size()))
463 return false;
464
465 // Verify that all of the indices are within the shape dimensions.
466 auto shape = type.getShape();
467 return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
468 return static_cast<int64_t>(index[i]) < shape[i];
469 });
470 }
471
472 ElementsAttr
mapValues(Type newElementType,function_ref<APInt (const APInt &)> mapping) const473 ElementsAttr::mapValues(Type newElementType,
474 function_ref<APInt(const APInt &)> mapping) const {
475 if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
476 return intOrFpAttr.mapValues(newElementType, mapping);
477 llvm_unreachable("unsupported ElementsAttr subtype");
478 }
479
480 ElementsAttr
mapValues(Type newElementType,function_ref<APInt (const APFloat &)> mapping) const481 ElementsAttr::mapValues(Type newElementType,
482 function_ref<APInt(const APFloat &)> mapping) const {
483 if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
484 return intOrFpAttr.mapValues(newElementType, mapping);
485 llvm_unreachable("unsupported ElementsAttr subtype");
486 }
487
488 /// Method for support type inquiry through isa, cast and dyn_cast.
classof(Attribute attr)489 bool ElementsAttr::classof(Attribute attr) {
490 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr,
491 OpaqueElementsAttr, SparseElementsAttr>();
492 }
493
494 /// Returns the 1 dimensional flattened row-major index from the given
495 /// multi-dimensional index.
getFlattenedIndex(ArrayRef<uint64_t> index) const496 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
497 assert(isValidIndex(index) && "expected valid multi-dimensional index");
498 auto type = getType();
499
500 // Reduce the provided multidimensional index into a flattended 1D row-major
501 // index.
502 auto rank = type.getRank();
503 auto shape = type.getShape();
504 uint64_t valueIndex = 0;
505 uint64_t dimMultiplier = 1;
506 for (int i = rank - 1; i >= 0; --i) {
507 valueIndex += index[i] * dimMultiplier;
508 dimMultiplier *= shape[i];
509 }
510 return valueIndex;
511 }
512
513 //===----------------------------------------------------------------------===//
514 // DenseElementsAttr Utilities
515 //===----------------------------------------------------------------------===//
516
517 /// Get the bitwidth of a dense element type within the buffer.
518 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
getDenseElementStorageWidth(size_t origWidth)519 static size_t getDenseElementStorageWidth(size_t origWidth) {
520 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
521 }
getDenseElementStorageWidth(Type elementType)522 static size_t getDenseElementStorageWidth(Type elementType) {
523 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
524 }
525
/// Set the bit at position `bitPos` of the buffer `rawData` to `value`. Bits
/// are numbered from the least significant bit of each byte upwards.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  if (value)
    byte |= mask;
  else
    byte &= ~mask;
}
533
/// Return the value of the bit at position `bitPos` of the buffer `rawData`.
/// Bits are numbered from the least significant bit of each byte upwards.
static bool getBit(const char *rawData, size_t bitPos) {
  const char byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  return (byte & mask) != 0;
}
538
539 /// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for
540 /// BE format.
copyAPIntToArrayForBEmachine(APInt value,size_t numBytes,char * result)541 static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
542 char *result) {
543 assert(llvm::support::endian::system_endianness() == // NOLINT
544 llvm::support::endianness::big); // NOLINT
545 assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);
546
547 // Copy the words filled with data.
548 // For example, when `value` has 2 words, the first word is filled with data.
549 // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
550 size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
551 std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
552 numFilledWords, result);
553 // Convert last word of APInt to LE format and store it in char
554 // array(`valueLE`).
555 // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------|
556 size_t lastWordPos = numFilledWords;
557 SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
558 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
559 reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
560 valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
561 // Extract actual APInt data from `valueLE`, convert endianness to BE format,
562 // and store it in `result`.
563 // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij|
564 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
565 valueLE.begin(), result + lastWordPos,
566 (numBytes - lastWordPos) * CHAR_BIT, 1);
567 }
568
569 /// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE
570 /// format.
copyArrayToAPIntForBEmachine(const char * inArray,size_t numBytes,APInt & result)571 static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
572 APInt &result) {
573 assert(llvm::support::endian::system_endianness() == // NOLINT
574 llvm::support::endianness::big); // NOLINT
575 assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);
576
577 // Copy the data that fills the word of `result` from `inArray`.
578 // For example, when `result` has 2 words, the first word will be filled with
579 // data. So, the first 8 bytes are copied from `inArray` here.
580 // `inArray` (10 bytes, BE): |abcdefgh|ij|
581 // ==> `result` (2 words, BE): |abcdefgh|--------|
582 size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
583 std::copy_n(
584 inArray, numFilledWords,
585 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));
586
587 // Convert array data which will be last word of `result` to LE format, and
588 // store it in char array(`inArrayLE`).
589 // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------|
590 size_t lastWordPos = numFilledWords;
591 SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
592 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
593 inArray + lastWordPos, inArrayLE.begin(),
594 (numBytes - lastWordPos) * CHAR_BIT, 1);
595
596 // Convert `inArrayLE` to BE format, and store it in last word of `result`.
597 // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij|
598 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
599 inArrayLE.begin(),
600 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
601 lastWordPos,
602 APInt::APINT_BITS_PER_WORD, 1);
603 }
604
605 /// Writes value to the bit position `bitPos` in array `rawData`.
writeBits(char * rawData,size_t bitPos,APInt value)606 static void writeBits(char *rawData, size_t bitPos, APInt value) {
607 size_t bitWidth = value.getBitWidth();
608
609 // If the bitwidth is 1 we just toggle the specific bit.
610 if (bitWidth == 1)
611 return setBit(rawData, bitPos, value.isOneValue());
612
613 // Otherwise, the bit position is guaranteed to be byte aligned.
614 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
615 if (llvm::support::endian::system_endianness() ==
616 llvm::support::endianness::big) {
617 // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
618 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
619 // work correctly in BE format.
620 // ex. `value` (2 words including 10 bytes)
621 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------|
622 copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT),
623 rawData + (bitPos / CHAR_BIT));
624 } else {
625 std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
626 llvm::divideCeil(bitWidth, CHAR_BIT),
627 rawData + (bitPos / CHAR_BIT));
628 }
629 }
630
631 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
632 /// `rawData`.
readBits(const char * rawData,size_t bitPos,size_t bitWidth)633 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
634 // Handle a boolean bit position.
635 if (bitWidth == 1)
636 return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
637
638 // Otherwise, the bit position must be 8-bit aligned.
639 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
640 APInt result(bitWidth, 0);
641 if (llvm::support::endian::system_endianness() ==
642 llvm::support::endianness::big) {
643 // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
644 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
645 // work correctly in BE format.
646 // ex. `result` (2 words including 10 bytes)
647 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function
648 copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT),
649 llvm::divideCeil(bitWidth, CHAR_BIT), result);
650 } else {
651 std::copy_n(rawData + (bitPos / CHAR_BIT),
652 llvm::divideCeil(bitWidth, CHAR_BIT),
653 const_cast<char *>(
654 reinterpret_cast<const char *>(result.getRawData())));
655 }
656 return result;
657 }
658
659 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has
660 /// the same element count as 'type'.
661 template <typename Values>
hasSameElementsOrSplat(ShapedType type,const Values & values)662 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
663 return (values.size() == 1) ||
664 (type.getNumElements() == static_cast<int64_t>(values.size()));
665 }
666
667 //===----------------------------------------------------------------------===//
668 // DenseElementsAttr Iterators
669 //===----------------------------------------------------------------------===//
670
671 //===----------------------------------------------------------------------===//
672 // AttributeElementIterator
673
AttributeElementIterator(DenseElementsAttr attr,size_t index)674 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
675 DenseElementsAttr attr, size_t index)
676 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
677 Attribute, Attribute, Attribute>(
678 attr.getAsOpaquePointer(), index) {}
679
operator *() const680 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
681 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
682 Type eltTy = owner.getType().getElementType();
683 if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
684 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
685 if (eltTy.isa<IndexType>())
686 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
687 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
688 IntElementIterator intIt(owner, index);
689 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
690 return FloatAttr::get(eltTy, *floatIt);
691 }
692 if (owner.isa<DenseStringElementsAttr>()) {
693 ArrayRef<StringRef> vals = owner.getRawStringData();
694 return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
695 }
696 llvm_unreachable("unexpected element type");
697 }
698
699 //===----------------------------------------------------------------------===//
700 // BoolElementIterator
701
BoolElementIterator(DenseElementsAttr attr,size_t dataIndex)702 DenseElementsAttr::BoolElementIterator::BoolElementIterator(
703 DenseElementsAttr attr, size_t dataIndex)
704 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
705 attr.getRawData().data(), attr.isSplat(), dataIndex) {}
706
operator *() const707 bool DenseElementsAttr::BoolElementIterator::operator*() const {
708 return getBit(getData(), getDataIndex());
709 }
710
711 //===----------------------------------------------------------------------===//
712 // IntElementIterator
713
IntElementIterator(DenseElementsAttr attr,size_t dataIndex)714 DenseElementsAttr::IntElementIterator::IntElementIterator(
715 DenseElementsAttr attr, size_t dataIndex)
716 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
717 attr.getRawData().data(), attr.isSplat(), dataIndex),
718 bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {}
719
operator *() const720 APInt DenseElementsAttr::IntElementIterator::operator*() const {
721 return readBits(getData(),
722 getDataIndex() * getDenseElementStorageWidth(bitWidth),
723 bitWidth);
724 }
725
726 //===----------------------------------------------------------------------===//
727 // ComplexIntElementIterator
728
ComplexIntElementIterator(DenseElementsAttr attr,size_t dataIndex)729 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
730 DenseElementsAttr attr, size_t dataIndex)
731 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
732 std::complex<APInt>, std::complex<APInt>,
733 std::complex<APInt>>(
734 attr.getRawData().data(), attr.isSplat(), dataIndex) {
735 auto complexType = attr.getType().getElementType().cast<ComplexType>();
736 bitWidth = getDenseElementBitWidth(complexType.getElementType());
737 }
738
739 std::complex<APInt>
operator *() const740 DenseElementsAttr::ComplexIntElementIterator::operator*() const {
741 size_t storageWidth = getDenseElementStorageWidth(bitWidth);
742 size_t offset = getDataIndex() * storageWidth * 2;
743 return {readBits(getData(), offset, bitWidth),
744 readBits(getData(), offset + storageWidth, bitWidth)};
745 }
746
747 //===----------------------------------------------------------------------===//
748 // FloatElementIterator
749
FloatElementIterator(const llvm::fltSemantics & smt,IntElementIterator it)750 DenseElementsAttr::FloatElementIterator::FloatElementIterator(
751 const llvm::fltSemantics &smt, IntElementIterator it)
752 : llvm::mapped_iterator<IntElementIterator,
753 std::function<APFloat(const APInt &)>>(
754 it, [&](const APInt &val) { return APFloat(smt, val); }) {}
755
756 //===----------------------------------------------------------------------===//
757 // ComplexFloatElementIterator
758
ComplexFloatElementIterator(const llvm::fltSemantics & smt,ComplexIntElementIterator it)759 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
760 const llvm::fltSemantics &smt, ComplexIntElementIterator it)
761 : llvm::mapped_iterator<
762 ComplexIntElementIterator,
763 std::function<std::complex<APFloat>(const std::complex<APInt> &)>>(
764 it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> {
765 return {APFloat(smt, val.real()), APFloat(smt, val.imag())};
766 }) {}
767
768 //===----------------------------------------------------------------------===//
769 // DenseElementsAttr
770 //===----------------------------------------------------------------------===//
771
772 /// Method for support type inquiry through isa, cast and dyn_cast.
classof(Attribute attr)773 bool DenseElementsAttr::classof(Attribute attr) {
774 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
775 }
776
get(ShapedType type,ArrayRef<Attribute> values)777 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
778 ArrayRef<Attribute> values) {
779 assert(hasSameElementsOrSplat(type, values));
780
781 // If the element type is not based on int/float/index, assume it is a string
782 // type.
783 auto eltType = type.getElementType();
784 if (!type.getElementType().isIntOrIndexOrFloat()) {
785 SmallVector<StringRef, 8> stringValues;
786 stringValues.reserve(values.size());
787 for (Attribute attr : values) {
788 assert(attr.isa<StringAttr>() &&
789 "expected string value for non integer/index/float element");
790 stringValues.push_back(attr.cast<StringAttr>().getValue());
791 }
792 return get(type, stringValues);
793 }
794
795 // Otherwise, get the raw storage width to use for the allocation.
796 size_t bitWidth = getDenseElementBitWidth(eltType);
797 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
798
799 // Compress the attribute values into a character buffer.
800 SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
801 values.size());
802 APInt intVal;
803 for (unsigned i = 0, e = values.size(); i < e; ++i) {
804 assert(eltType == values[i].getType() &&
805 "expected attribute value to have element type");
806 if (eltType.isa<FloatType>())
807 intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
808 else if (eltType.isa<IntegerType>())
809 intVal = values[i].cast<IntegerAttr>().getValue();
810 else
811 llvm_unreachable("unexpected element type");
812
813 assert(intVal.getBitWidth() == bitWidth &&
814 "expected value to have same bitwidth as element type");
815 writeBits(data.data(), i * storageBitWidth, intVal);
816 }
817 return DenseIntOrFPElementsAttr::getRaw(type, data,
818 /*isSplat=*/(values.size() == 1));
819 }
820
get(ShapedType type,ArrayRef<bool> values)821 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
822 ArrayRef<bool> values) {
823 assert(hasSameElementsOrSplat(type, values));
824 assert(type.getElementType().isInteger(1));
825
826 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
827 for (int i = 0, e = values.size(); i != e; ++i)
828 setBit(buff.data(), i, values[i]);
829 return DenseIntOrFPElementsAttr::getRaw(type, buff,
830 /*isSplat=*/(values.size() == 1));
831 }
832
get(ShapedType type,ArrayRef<StringRef> values)833 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
834 ArrayRef<StringRef> values) {
835 assert(!type.getElementType().isIntOrFloat());
836 return DenseStringElementsAttr::get(type, values);
837 }
838
839 /// Constructs a dense integer elements attribute from an array of APInt
840 /// values. Each APInt value is expected to have the same bitwidth as the
841 /// element type of 'type'.
get(ShapedType type,ArrayRef<APInt> values)842 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
843 ArrayRef<APInt> values) {
844 assert(type.getElementType().isIntOrIndex());
845 assert(hasSameElementsOrSplat(type, values));
846 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
847 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
848 /*isSplat=*/(values.size() == 1));
849 }
get(ShapedType type,ArrayRef<std::complex<APInt>> values)850 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
851 ArrayRef<std::complex<APInt>> values) {
852 ComplexType complex = type.getElementType().cast<ComplexType>();
853 assert(complex.getElementType().isa<IntegerType>());
854 assert(hasSameElementsOrSplat(type, values));
855 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
856 ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
857 values.size() * 2);
858 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals,
859 /*isSplat=*/(values.size() == 1));
860 }
861
862 // Constructs a dense float elements attribute from an array of APFloat
863 // values. Each APFloat value is expected to have the same bitwidth as the
864 // element type of 'type'.
get(ShapedType type,ArrayRef<APFloat> values)865 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
866 ArrayRef<APFloat> values) {
867 assert(type.getElementType().isa<FloatType>());
868 assert(hasSameElementsOrSplat(type, values));
869 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
870 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
871 /*isSplat=*/(values.size() == 1));
872 }
873 DenseElementsAttr
get(ShapedType type,ArrayRef<std::complex<APFloat>> values)874 DenseElementsAttr::get(ShapedType type,
875 ArrayRef<std::complex<APFloat>> values) {
876 ComplexType complex = type.getElementType().cast<ComplexType>();
877 assert(complex.getElementType().isa<FloatType>());
878 assert(hasSameElementsOrSplat(type, values));
879 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
880 values.size() * 2);
881 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
882 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals,
883 /*isSplat=*/(values.size() == 1));
884 }
885
886 /// Construct a dense elements attribute from a raw buffer representing the
887 /// data for this attribute. Users should generally not use this methods as
888 /// the expected buffer format may not be a form the user expects.
getFromRawBuffer(ShapedType type,ArrayRef<char> rawBuffer,bool isSplatBuffer)889 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
890 ArrayRef<char> rawBuffer,
891 bool isSplatBuffer) {
892 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer);
893 }
894
895 /// Returns true if the given buffer is a valid raw buffer for the given type.
isValidRawBuffer(ShapedType type,ArrayRef<char> rawBuffer,bool & detectedSplat)896 bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
897 ArrayRef<char> rawBuffer,
898 bool &detectedSplat) {
899 size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
900 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
901
902 // Storage width of 1 is special as it is packed by the bit.
903 if (storageWidth == 1) {
904 // Check for a splat, or a buffer equal to the number of elements.
905 if ((detectedSplat = rawBuffer.size() == 1))
906 return true;
907 return rawBufferWidth == llvm::alignTo<8>(type.getNumElements());
908 }
909 // All other types are 8-bit aligned.
910 if ((detectedSplat = rawBufferWidth == storageWidth))
911 return true;
912 return rawBufferWidth == (storageWidth * type.getNumElements());
913 }
914
915 /// Check the information for a C++ data type, check if this type is valid for
916 /// the current attribute. This method is used to verify specific type
917 /// invariants that the templatized 'getValues' method cannot.
isValidIntOrFloat(Type type,int64_t dataEltSize,bool isInt,bool isSigned)918 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
919 bool isSigned) {
920 // Make sure that the data element size is the same as the type element width.
921 if (getDenseElementBitWidth(type) !=
922 static_cast<size_t>(dataEltSize * CHAR_BIT))
923 return false;
924
925 // Check that the element type is either float or integer or index.
926 if (!isInt)
927 return type.isa<FloatType>();
928 if (type.isIndex())
929 return true;
930
931 auto intType = type.dyn_cast<IntegerType>();
932 if (!intType)
933 return false;
934
935 // Make sure signedness semantics is consistent.
936 if (intType.isSignless())
937 return true;
938 return intType.isSigned() ? isSigned : !isSigned;
939 }
940
941 /// Defaults down the subclass implementation.
getRawComplex(ShapedType type,ArrayRef<char> data,int64_t dataEltSize,bool isInt,bool isSigned)942 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
943 ArrayRef<char> data,
944 int64_t dataEltSize,
945 bool isInt, bool isSigned) {
946 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
947 isSigned);
948 }
getRawIntOrFloat(ShapedType type,ArrayRef<char> data,int64_t dataEltSize,bool isInt,bool isSigned)949 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
950 ArrayRef<char> data,
951 int64_t dataEltSize,
952 bool isInt,
953 bool isSigned) {
954 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
955 isInt, isSigned);
956 }
957
958 /// A method used to verify specific type invariants that the templatized 'get'
959 /// method cannot.
isValidIntOrFloat(int64_t dataEltSize,bool isInt,bool isSigned) const960 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
961 bool isSigned) const {
962 return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt,
963 isSigned);
964 }
965
966 /// Check the information for a C++ data type, check if this type is valid for
967 /// the current attribute.
isValidComplex(int64_t dataEltSize,bool isInt,bool isSigned) const968 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
969 bool isSigned) const {
970 return ::isValidIntOrFloat(
971 getType().getElementType().cast<ComplexType>().getElementType(),
972 dataEltSize / 2, isInt, isSigned);
973 }
974
975 /// Returns true if this attribute corresponds to a splat, i.e. if all element
976 /// values are the same.
isSplat() const977 bool DenseElementsAttr::isSplat() const {
978 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
979 }
980
981 /// Return the held element values as a range of Attributes.
getAttributeValues() const982 auto DenseElementsAttr::getAttributeValues() const
983 -> llvm::iterator_range<AttributeElementIterator> {
984 return {attr_value_begin(), attr_value_end()};
985 }
attr_value_begin() const986 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
987 return AttributeElementIterator(*this, 0);
988 }
attr_value_end() const989 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
990 return AttributeElementIterator(*this, getNumElements());
991 }
992
993 /// Return the held element values as a range of bool. The element type of
994 /// this attribute must be of integer type of bitwidth 1.
getBoolValues() const995 auto DenseElementsAttr::getBoolValues() const
996 -> llvm::iterator_range<BoolElementIterator> {
997 auto eltType = getType().getElementType().dyn_cast<IntegerType>();
998 assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type");
999 (void)eltType;
1000 return {BoolElementIterator(*this, 0),
1001 BoolElementIterator(*this, getNumElements())};
1002 }
1003
1004 /// Return the held element values as a range of APInts. The element type of
1005 /// this attribute must be of integer type.
getIntValues() const1006 auto DenseElementsAttr::getIntValues() const
1007 -> llvm::iterator_range<IntElementIterator> {
1008 assert(getType().getElementType().isIntOrIndex() && "expected integral type");
1009 return {raw_int_begin(), raw_int_end()};
1010 }
int_value_begin() const1011 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
1012 assert(getType().getElementType().isIntOrIndex() && "expected integral type");
1013 return raw_int_begin();
1014 }
int_value_end() const1015 auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
1016 assert(getType().getElementType().isIntOrIndex() && "expected integral type");
1017 return raw_int_end();
1018 }
getComplexIntValues() const1019 auto DenseElementsAttr::getComplexIntValues() const
1020 -> llvm::iterator_range<ComplexIntElementIterator> {
1021 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
1022 (void)eltTy;
1023 assert(eltTy.isa<IntegerType>() && "expected complex integral type");
1024 return {ComplexIntElementIterator(*this, 0),
1025 ComplexIntElementIterator(*this, getNumElements())};
1026 }
1027
1028 /// Return the held element values as a range of APFloat. The element type of
1029 /// this attribute must be of float type.
getFloatValues() const1030 auto DenseElementsAttr::getFloatValues() const
1031 -> llvm::iterator_range<FloatElementIterator> {
1032 auto elementType = getType().getElementType().cast<FloatType>();
1033 const auto &elementSemantics = elementType.getFloatSemantics();
1034 return {FloatElementIterator(elementSemantics, raw_int_begin()),
1035 FloatElementIterator(elementSemantics, raw_int_end())};
1036 }
float_value_begin() const1037 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
1038 return getFloatValues().begin();
1039 }
float_value_end() const1040 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
1041 return getFloatValues().end();
1042 }
getComplexFloatValues() const1043 auto DenseElementsAttr::getComplexFloatValues() const
1044 -> llvm::iterator_range<ComplexFloatElementIterator> {
1045 Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
1046 assert(eltTy.isa<FloatType>() && "expected complex float type");
1047 const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
1048 return {{semantics, {*this, 0}},
1049 {semantics, {*this, static_cast<size_t>(getNumElements())}}};
1050 }
1051
1052 /// Return the raw storage data held by this attribute.
getRawData() const1053 ArrayRef<char> DenseElementsAttr::getRawData() const {
1054 return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data;
1055 }
1056
getRawStringData() const1057 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
1058 return static_cast<DenseStringElementsAttributeStorage *>(impl)->data;
1059 }
1060
1061 /// Return a new DenseElementsAttr that has the same data as the current
1062 /// attribute, but has been reshaped to 'newType'. The new type must have the
1063 /// same total number of elements as well as element type.
reshape(ShapedType newType)1064 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
1065 ShapedType curType = getType();
1066 if (curType == newType)
1067 return *this;
1068
1069 (void)curType;
1070 assert(newType.getElementType() == curType.getElementType() &&
1071 "expected the same element type");
1072 assert(newType.getNumElements() == curType.getNumElements() &&
1073 "expected the same number of elements");
1074 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
1075 }
1076
1077 DenseElementsAttr
mapValues(Type newElementType,function_ref<APInt (const APInt &)> mapping) const1078 DenseElementsAttr::mapValues(Type newElementType,
1079 function_ref<APInt(const APInt &)> mapping) const {
1080 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
1081 }
1082
mapValues(Type newElementType,function_ref<APInt (const APFloat &)> mapping) const1083 DenseElementsAttr DenseElementsAttr::mapValues(
1084 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1085 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
1086 }
1087
1088 //===----------------------------------------------------------------------===//
1089 // DenseStringElementsAttr
1090 //===----------------------------------------------------------------------===//
1091
1092 DenseStringElementsAttr
get(ShapedType type,ArrayRef<StringRef> values)1093 DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) {
1094 return Base::get(type.getContext(), type, values, (values.size() == 1));
1095 }
1096
1097 //===----------------------------------------------------------------------===//
1098 // DenseIntOrFPElementsAttr
1099 //===----------------------------------------------------------------------===//
1100
1101 /// Utility method to write a range of APInt values to a buffer.
1102 template <typename APRangeT>
writeAPIntsToBuffer(size_t storageWidth,std::vector<char> & data,APRangeT && values)1103 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
1104 APRangeT &&values) {
1105 data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values));
1106 size_t offset = 0;
1107 for (auto it = values.begin(), e = values.end(); it != e;
1108 ++it, offset += storageWidth) {
1109 assert((*it).getBitWidth() <= storageWidth);
1110 writeBits(data.data(), offset, *it);
1111 }
1112 }
1113
1114 /// Constructs a dense elements attribute from an array of raw APFloat values.
1115 /// Each APFloat value is expected to have the same bitwidth as the element
1116 /// type of 'type'. 'type' must be a vector or tensor with static shape.
getRaw(ShapedType type,size_t storageWidth,ArrayRef<APFloat> values,bool isSplat)1117 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1118 size_t storageWidth,
1119 ArrayRef<APFloat> values,
1120 bool isSplat) {
1121 std::vector<char> data;
1122 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1123 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
1124 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1125 }
1126
1127 /// Constructs a dense elements attribute from an array of raw APInt values.
1128 /// Each APInt value is expected to have the same bitwidth as the element type
1129 /// of 'type'.
getRaw(ShapedType type,size_t storageWidth,ArrayRef<APInt> values,bool isSplat)1130 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1131 size_t storageWidth,
1132 ArrayRef<APInt> values,
1133 bool isSplat) {
1134 std::vector<char> data;
1135 writeAPIntsToBuffer(storageWidth, data, values);
1136 return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1137 }
1138
getRaw(ShapedType type,ArrayRef<char> data,bool isSplat)1139 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1140 ArrayRef<char> data,
1141 bool isSplat) {
1142 assert((type.isa<RankedTensorType, VectorType>()) &&
1143 "type must be ranked tensor or vector");
1144 assert(type.hasStaticShape() && "type must have static shape");
1145 return Base::get(type.getContext(), type, data, isSplat);
1146 }
1147
1148 /// Overload of the raw 'get' method that asserts that the given type is of
1149 /// complex type. This method is used to verify type invariants that the
1150 /// templatized 'get' method cannot.
getRawComplex(ShapedType type,ArrayRef<char> data,int64_t dataEltSize,bool isInt,bool isSigned)1151 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1152 ArrayRef<char> data,
1153 int64_t dataEltSize,
1154 bool isInt,
1155 bool isSigned) {
1156 assert(::isValidIntOrFloat(
1157 type.getElementType().cast<ComplexType>().getElementType(),
1158 dataEltSize / 2, isInt, isSigned));
1159
1160 int64_t numElements = data.size() / dataEltSize;
1161 assert(numElements == 1 || numElements == type.getNumElements());
1162 return getRaw(type, data, /*isSplat=*/numElements == 1);
1163 }
1164
1165 /// Overload of the 'getRaw' method that asserts that the given type is of
1166 /// integer type. This method is used to verify type invariants that the
1167 /// templatized 'get' method cannot.
1168 DenseElementsAttr
getRawIntOrFloat(ShapedType type,ArrayRef<char> data,int64_t dataEltSize,bool isInt,bool isSigned)1169 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1170 int64_t dataEltSize, bool isInt,
1171 bool isSigned) {
1172 assert(
1173 ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned));
1174
1175 int64_t numElements = data.size() / dataEltSize;
1176 assert(numElements == 1 || numElements == type.getNumElements());
1177 return getRaw(type, data, /*isSplat=*/numElements == 1);
1178 }
1179
convertEndianOfCharForBEmachine(const char * inRawData,char * outRawData,size_t elementBitWidth,size_t numElements)1180 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
1181 const char *inRawData, char *outRawData, size_t elementBitWidth,
1182 size_t numElements) {
1183 using llvm::support::ulittle16_t;
1184 using llvm::support::ulittle32_t;
1185 using llvm::support::ulittle64_t;
1186
1187 assert(llvm::support::endian::system_endianness() == // NOLINT
1188 llvm::support::endianness::big); // NOLINT
1189 // NOLINT to avoid warning message about replacing by static_assert()
1190
1191 // Following std::copy_n always converts endianness on BE machine.
1192 switch (elementBitWidth) {
1193 case 16: {
1194 const ulittle16_t *inRawDataPos =
1195 reinterpret_cast<const ulittle16_t *>(inRawData);
1196 uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData);
1197 std::copy_n(inRawDataPos, numElements, outDataPos);
1198 break;
1199 }
1200 case 32: {
1201 const ulittle32_t *inRawDataPos =
1202 reinterpret_cast<const ulittle32_t *>(inRawData);
1203 uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData);
1204 std::copy_n(inRawDataPos, numElements, outDataPos);
1205 break;
1206 }
1207 case 64: {
1208 const ulittle64_t *inRawDataPos =
1209 reinterpret_cast<const ulittle64_t *>(inRawData);
1210 uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData);
1211 std::copy_n(inRawDataPos, numElements, outDataPos);
1212 break;
1213 }
1214 default: {
1215 size_t nBytes = elementBitWidth / CHAR_BIT;
1216 for (size_t i = 0; i < nBytes; i++)
1217 std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i);
1218 break;
1219 }
1220 }
1221 }
1222
convertEndianOfArrayRefForBEmachine(ArrayRef<char> inRawData,MutableArrayRef<char> outRawData,ShapedType type)1223 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1224 ArrayRef<char> inRawData, MutableArrayRef<char> outRawData,
1225 ShapedType type) {
1226 size_t numElements = type.getNumElements();
1227 Type elementType = type.getElementType();
1228 if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1229 elementType = complexTy.getElementType();
1230 numElements = numElements * 2;
1231 }
1232 size_t elementBitWidth = getDenseElementStorageWidth(elementType);
1233 assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&
1234 inRawData.size() <= outRawData.size());
1235 convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(),
1236 elementBitWidth, numElements);
1237 }
1238
1239 //===----------------------------------------------------------------------===//
1240 // DenseFPElementsAttr
1241 //===----------------------------------------------------------------------===//
1242
1243 template <typename Fn, typename Attr>
mappingHelper(Fn mapping,Attr & attr,ShapedType inType,Type newElementType,llvm::SmallVectorImpl<char> & data)1244 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1245 Type newElementType,
1246 llvm::SmallVectorImpl<char> &data) {
1247 size_t bitWidth = getDenseElementBitWidth(newElementType);
1248 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1249
1250 ShapedType newArrayType;
1251 if (inType.isa<RankedTensorType>())
1252 newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1253 else if (inType.isa<UnrankedTensorType>())
1254 newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1255 else if (inType.isa<VectorType>())
1256 newArrayType = VectorType::get(inType.getShape(), newElementType);
1257 else
1258 assert(newArrayType && "Unhandled tensor type");
1259
1260 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1261 data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
1262
1263 // Functor used to process a single element value of the attribute.
1264 auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1265 auto newInt = mapping(value);
1266 assert(newInt.getBitWidth() == bitWidth);
1267 writeBits(data.data(), index * storageBitWidth, newInt);
1268 };
1269
1270 // Check for the splat case.
1271 if (attr.isSplat()) {
1272 processElt(*attr.begin(), /*index=*/0);
1273 return newArrayType;
1274 }
1275
1276 // Otherwise, process all of the element values.
1277 uint64_t elementIdx = 0;
1278 for (auto value : attr)
1279 processElt(value, elementIdx++);
1280 return newArrayType;
1281 }
1282
mapValues(Type newElementType,function_ref<APInt (const APFloat &)> mapping) const1283 DenseElementsAttr DenseFPElementsAttr::mapValues(
1284 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1285 llvm::SmallVector<char, 8> elementData;
1286 auto newArrayType =
1287 mappingHelper(mapping, *this, getType(), newElementType, elementData);
1288
1289 return getRaw(newArrayType, elementData, isSplat());
1290 }
1291
1292 /// Method for supporting type inquiry through isa, cast and dyn_cast.
classof(Attribute attr)1293 bool DenseFPElementsAttr::classof(Attribute attr) {
1294 return attr.isa<DenseElementsAttr>() &&
1295 attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
1296 }
1297
1298 //===----------------------------------------------------------------------===//
1299 // DenseIntElementsAttr
1300 //===----------------------------------------------------------------------===//
1301
mapValues(Type newElementType,function_ref<APInt (const APInt &)> mapping) const1302 DenseElementsAttr DenseIntElementsAttr::mapValues(
1303 Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1304 llvm::SmallVector<char, 8> elementData;
1305 auto newArrayType =
1306 mappingHelper(mapping, *this, getType(), newElementType, elementData);
1307
1308 return getRaw(newArrayType, elementData, isSplat());
1309 }
1310
1311 /// Method for supporting type inquiry through isa, cast and dyn_cast.
classof(Attribute attr)1312 bool DenseIntElementsAttr::classof(Attribute attr) {
1313 return attr.isa<DenseElementsAttr>() &&
1314 attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
1315 }
1316
1317 //===----------------------------------------------------------------------===//
1318 // OpaqueElementsAttr
1319 //===----------------------------------------------------------------------===//
1320
get(Dialect * dialect,ShapedType type,StringRef bytes)1321 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
1322 StringRef bytes) {
1323 assert(TensorType::isValidElementType(type.getElementType()) &&
1324 "Input element type should be a valid tensor element type");
1325 return Base::get(type.getContext(), type, dialect, bytes);
1326 }
1327
getValue() const1328 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
1329
1330 /// Return the value at the given index. If index does not refer to a valid
1331 /// element, then a null attribute is returned.
getValue(ArrayRef<uint64_t> index) const1332 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
1333 assert(isValidIndex(index) && "expected valid multi-dimensional index");
1334 return Attribute();
1335 }
1336
getDialect() const1337 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }
1338
decode(ElementsAttr & result)1339 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
1340 auto *d = getDialect();
1341 if (!d)
1342 return true;
1343 auto *interface =
1344 d->getRegisteredInterface<DialectDecodeAttributesInterface>();
1345 if (!interface)
1346 return true;
1347 return failed(interface->decode(*this, result));
1348 }
1349
1350 //===----------------------------------------------------------------------===//
1351 // SparseElementsAttr
1352 //===----------------------------------------------------------------------===//
1353
get(ShapedType type,DenseElementsAttr indices,DenseElementsAttr values)1354 SparseElementsAttr SparseElementsAttr::get(ShapedType type,
1355 DenseElementsAttr indices,
1356 DenseElementsAttr values) {
1357 assert(indices.getType().getElementType().isInteger(64) &&
1358 "expected sparse indices to be 64-bit integer values");
1359 assert((type.isa<RankedTensorType, VectorType>()) &&
1360 "type must be ranked tensor or vector");
1361 assert(type.hasStaticShape() && "type must have static shape");
1362 return Base::get(type.getContext(), type,
1363 indices.cast<DenseIntElementsAttr>(), values);
1364 }
1365
getIndices() const1366 DenseIntElementsAttr SparseElementsAttr::getIndices() const {
1367 return getImpl()->indices;
1368 }
1369
getValues() const1370 DenseElementsAttr SparseElementsAttr::getValues() const {
1371 return getImpl()->values;
1372 }
1373
1374 /// Return the value of the element at the given index.
getValue(ArrayRef<uint64_t> index) const1375 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
1376 assert(isValidIndex(index) && "expected valid multi-dimensional index");
1377 auto type = getType();
1378
1379 // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1380 // as a 1-D index array.
1381 auto sparseIndices = getIndices();
1382 auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1383
1384 // Check to see if the indices are a splat.
1385 if (sparseIndices.isSplat()) {
1386 // If the index is also not a splat of the index value, we know that the
1387 // value is zero.
1388 auto splatIndex = *sparseIndexValues.begin();
1389 if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
1390 return getZeroAttr();
1391
1392 // If the indices are a splat, we also expect the values to be a splat.
1393 assert(getValues().isSplat() && "expected splat values");
1394 return getValues().getSplatValue();
1395 }
1396
1397 // Build a mapping between known indices and the offset of the stored element.
1398 llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
1399 auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1400 size_t rank = type.getRank();
1401 for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1402 mappedIndices.try_emplace(
1403 {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);
1404
1405 // Look for the provided index key within the mapped indices. If the provided
1406 // index is not found, then return a zero attribute.
1407 auto it = mappedIndices.find(index);
1408 if (it == mappedIndices.end())
1409 return getZeroAttr();
1410
1411 // Otherwise, return the held sparse value element.
1412 return getValues().getValue(it->second);
1413 }
1414
1415 /// Get a zero APFloat for the given sparse attribute.
getZeroAPFloat() const1416 APFloat SparseElementsAttr::getZeroAPFloat() const {
1417 auto eltType = getType().getElementType().cast<FloatType>();
1418 return APFloat(eltType.getFloatSemantics());
1419 }
1420
1421 /// Get a zero APInt for the given sparse attribute.
getZeroAPInt() const1422 APInt SparseElementsAttr::getZeroAPInt() const {
1423 auto eltType = getType().getElementType().cast<IntegerType>();
1424 return APInt::getNullValue(eltType.getWidth());
1425 }
1426
1427 /// Get a zero attribute for the given attribute type.
getZeroAttr() const1428 Attribute SparseElementsAttr::getZeroAttr() const {
1429 auto eltType = getType().getElementType();
1430
1431 // Handle floating point elements.
1432 if (eltType.isa<FloatType>())
1433 return FloatAttr::get(eltType, 0);
1434
1435 // Otherwise, this is an integer.
1436 // TODO: Handle StringAttr here.
1437 return IntegerAttr::get(eltType, 0);
1438 }
1439
1440 /// Flatten, and return, all of the sparse indices in this attribute in
1441 /// row-major order.
getFlattenedSparseIndices() const1442 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1443 std::vector<ptrdiff_t> flatSparseIndices;
1444
1445 // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1446 // as a 1-D index array.
1447 auto sparseIndices = getIndices();
1448 auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1449 if (sparseIndices.isSplat()) {
1450 SmallVector<uint64_t, 8> indices(getType().getRank(),
1451 *sparseIndexValues.begin());
1452 flatSparseIndices.push_back(getFlattenedIndex(indices));
1453 return flatSparseIndices;
1454 }
1455
1456 // Otherwise, reinterpret each index as an ArrayRef when flattening.
1457 auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1458 size_t rank = getType().getRank();
1459 for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1460 flatSparseIndices.push_back(getFlattenedIndex(
1461 {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1462 return flatSparseIndices;
1463 }
1464
1465 //===----------------------------------------------------------------------===//
1466 // MutableDictionaryAttr
1467 //===----------------------------------------------------------------------===//
1468
MutableDictionaryAttr(ArrayRef<NamedAttribute> attributes)1469 MutableDictionaryAttr::MutableDictionaryAttr(
1470 ArrayRef<NamedAttribute> attributes) {
1471 setAttrs(attributes);
1472 }
1473
1474 /// Return the underlying dictionary attribute.
1475 DictionaryAttr
getDictionary(MLIRContext * context) const1476 MutableDictionaryAttr::getDictionary(MLIRContext *context) const {
1477 // Construct empty DictionaryAttr if needed.
1478 if (!attrs)
1479 return DictionaryAttr::get({}, context);
1480 return attrs;
1481 }
1482
getAttrs() const1483 ArrayRef<NamedAttribute> MutableDictionaryAttr::getAttrs() const {
1484 return attrs ? attrs.getValue() : llvm::None;
1485 }
1486
1487 /// Replace the held attributes with ones provided in 'newAttrs'.
setAttrs(ArrayRef<NamedAttribute> attributes)1488 void MutableDictionaryAttr::setAttrs(ArrayRef<NamedAttribute> attributes) {
1489 // Don't create an attribute list if there are no attributes.
1490 if (attributes.empty())
1491 attrs = nullptr;
1492 else
1493 attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext());
1494 }
1495
1496 /// Return the specified attribute if present, null otherwise.
get(StringRef name) const1497 Attribute MutableDictionaryAttr::get(StringRef name) const {
1498 return attrs ? attrs.get(name) : nullptr;
1499 }
1500
1501 /// Return the specified attribute if present, null otherwise.
get(Identifier name) const1502 Attribute MutableDictionaryAttr::get(Identifier name) const {
1503 return attrs ? attrs.get(name) : nullptr;
1504 }
1505
1506 /// Return the specified named attribute if present, None otherwise.
getNamed(StringRef name) const1507 Optional<NamedAttribute> MutableDictionaryAttr::getNamed(StringRef name) const {
1508 return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
1509 }
1510 Optional<NamedAttribute>
getNamed(Identifier name) const1511 MutableDictionaryAttr::getNamed(Identifier name) const {
1512 return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
1513 }
1514
1515 /// If the an attribute exists with the specified name, change it to the new
1516 /// value. Otherwise, add a new attribute with the specified name/value.
set(Identifier name,Attribute value)1517 void MutableDictionaryAttr::set(Identifier name, Attribute value) {
1518 assert(value && "attributes may never be null");
1519
1520 // Look for an existing value for the given name, and set it in-place.
1521 ArrayRef<NamedAttribute> values = getAttrs();
1522 const auto *it = llvm::find_if(
1523 values, [name](NamedAttribute attr) { return attr.first == name; });
1524 if (it != values.end()) {
1525 // Bail out early if the value is the same as what we already have.
1526 if (it->second == value)
1527 return;
1528
1529 SmallVector<NamedAttribute, 8> newAttrs(values.begin(), values.end());
1530 newAttrs[it - values.begin()].second = value;
1531 attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext());
1532 return;
1533 }
1534
1535 // Otherwise, insert the new attribute into its sorted position.
1536 it = llvm::lower_bound(values, name);
1537 SmallVector<NamedAttribute, 8> newAttrs;
1538 newAttrs.reserve(values.size() + 1);
1539 newAttrs.append(values.begin(), it);
1540 newAttrs.push_back({name, value});
1541 newAttrs.append(it, values.end());
1542 attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext());
1543 }
1544
1545 /// Remove the attribute with the specified name if it exists. The return
1546 /// value indicates whether the attribute was present or not.
remove(Identifier name)1547 auto MutableDictionaryAttr::remove(Identifier name) -> RemoveResult {
1548 auto origAttrs = getAttrs();
1549 for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) {
1550 if (origAttrs[i].first == name) {
1551 // Handle the simple case of removing the only attribute in the list.
1552 if (e == 1) {
1553 attrs = nullptr;
1554 return RemoveResult::Removed;
1555 }
1556
1557 SmallVector<NamedAttribute, 8> newAttrs;
1558 newAttrs.reserve(origAttrs.size() - 1);
1559 newAttrs.append(origAttrs.begin(), origAttrs.begin() + i);
1560 newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end());
1561 attrs = DictionaryAttr::getWithSorted(newAttrs,
1562 newAttrs[0].second.getContext());
1563 return RemoveResult::Removed;
1564 }
1565 }
1566 return RemoveResult::NotFound;
1567 }
1568