• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinAttributes.h"
10 #include "AttributeDetail.h"
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/Diagnostics.h"
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/IntegerSet.h"
15 #include "mlir/IR/Types.h"
16 #include "mlir/Interfaces/DecodeAttributesInterfaces.h"
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Endian.h"
20 
21 using namespace mlir;
22 using namespace mlir::detail;
23 
24 //===----------------------------------------------------------------------===//
25 // AffineMapAttr
26 //===----------------------------------------------------------------------===//
27 
get(AffineMap value)28 AffineMapAttr AffineMapAttr::get(AffineMap value) {
29   return Base::get(value.getContext(), value);
30 }
31 
getValue() const32 AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
33 
34 //===----------------------------------------------------------------------===//
35 // ArrayAttr
36 //===----------------------------------------------------------------------===//
37 
get(ArrayRef<Attribute> value,MLIRContext * context)38 ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
39   return Base::get(context, value);
40 }
41 
getValue() const42 ArrayRef<Attribute> ArrayAttr::getValue() const { return getImpl()->value; }
43 
operator [](unsigned idx) const44 Attribute ArrayAttr::operator[](unsigned idx) const {
45   assert(idx < size() && "index out of bounds");
46   return getValue()[idx];
47 }
48 
49 //===----------------------------------------------------------------------===//
50 // DictionaryAttr
51 //===----------------------------------------------------------------------===//
52 
53 /// Helper function that does either an in place sort or sorts from source array
54 /// into destination. If inPlace then storage is both the source and the
55 /// destination, else value is the source and storage destination. Returns
56 /// whether source was sorted.
57 template <bool inPlace>
dictionaryAttrSort(ArrayRef<NamedAttribute> value,SmallVectorImpl<NamedAttribute> & storage)58 static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
59                                SmallVectorImpl<NamedAttribute> &storage) {
60   // Specialize for the common case.
61   switch (value.size()) {
62   case 0:
63     // Zero already sorted.
64     break;
65   case 1:
66     // One already sorted but may need to be copied.
67     if (!inPlace)
68       storage.assign({value[0]});
69     break;
70   case 2: {
71     bool isSorted = value[0] < value[1];
72     if (inPlace) {
73       if (!isSorted)
74         std::swap(storage[0], storage[1]);
75     } else if (isSorted) {
76       storage.assign({value[0], value[1]});
77     } else {
78       storage.assign({value[1], value[0]});
79     }
80     return !isSorted;
81   }
82   default:
83     if (!inPlace)
84       storage.assign(value.begin(), value.end());
85     // Check to see they are sorted already.
86     bool isSorted = llvm::is_sorted(value);
87     if (!isSorted) {
88       // If not, do a general sort.
89       llvm::array_pod_sort(storage.begin(), storage.end());
90       value = storage;
91     }
92     return !isSorted;
93   }
94   return false;
95 }
96 
97 /// Returns an entry with a duplicate name from the given sorted array of named
98 /// attributes. Returns llvm::None if all elements have unique names.
99 static Optional<NamedAttribute>
findDuplicateElement(ArrayRef<NamedAttribute> value)100 findDuplicateElement(ArrayRef<NamedAttribute> value) {
101   const Optional<NamedAttribute> none{llvm::None};
102   if (value.size() < 2)
103     return none;
104 
105   if (value.size() == 2)
106     return value[0].first == value[1].first ? value[0] : none;
107 
108   auto it = std::adjacent_find(
109       value.begin(), value.end(),
110       [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; });
111   return it != value.end() ? *it : none;
112 }
113 
sort(ArrayRef<NamedAttribute> value,SmallVectorImpl<NamedAttribute> & storage)114 bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
115                           SmallVectorImpl<NamedAttribute> &storage) {
116   bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage);
117   assert(!findDuplicateElement(storage) &&
118          "DictionaryAttr element names must be unique");
119   return isSorted;
120 }
121 
sortInPlace(SmallVectorImpl<NamedAttribute> & array)122 bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
123   bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array);
124   assert(!findDuplicateElement(array) &&
125          "DictionaryAttr element names must be unique");
126   return isSorted;
127 }
128 
129 Optional<NamedAttribute>
findDuplicate(SmallVectorImpl<NamedAttribute> & array,bool isSorted)130 DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
131                               bool isSorted) {
132   if (!isSorted)
133     dictionaryAttrSort</*inPlace=*/true>(array, array);
134   return findDuplicateElement(array);
135 }
136 
get(ArrayRef<NamedAttribute> value,MLIRContext * context)137 DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
138                                    MLIRContext *context) {
139   if (value.empty())
140     return DictionaryAttr::getEmpty(context);
141   assert(llvm::all_of(value,
142                       [](const NamedAttribute &attr) { return attr.second; }) &&
143          "value cannot have null entries");
144 
145   // We need to sort the element list to canonicalize it.
146   SmallVector<NamedAttribute, 8> storage;
147   if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
148     value = storage;
149   assert(!findDuplicateElement(value) &&
150          "DictionaryAttr element names must be unique");
151   return Base::get(context, value);
152 }
153 /// Construct a dictionary with an array of values that is known to already be
154 /// sorted by name and uniqued.
getWithSorted(ArrayRef<NamedAttribute> value,MLIRContext * context)155 DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef<NamedAttribute> value,
156                                              MLIRContext *context) {
157   if (value.empty())
158     return DictionaryAttr::getEmpty(context);
159   // Ensure that the attribute elements are unique and sorted.
160   assert(llvm::is_sorted(value,
161                          [](NamedAttribute l, NamedAttribute r) {
162                            return l.first.strref() < r.first.strref();
163                          }) &&
164          "expected attribute values to be sorted");
165   assert(!findDuplicateElement(value) &&
166          "DictionaryAttr element names must be unique");
167   return Base::get(context, value);
168 }
169 
getValue() const170 ArrayRef<NamedAttribute> DictionaryAttr::getValue() const {
171   return getImpl()->getElements();
172 }
173 
174 /// Return the specified attribute if present, null otherwise.
get(StringRef name) const175 Attribute DictionaryAttr::get(StringRef name) const {
176   Optional<NamedAttribute> attr = getNamed(name);
177   return attr ? attr->second : nullptr;
178 }
get(Identifier name) const179 Attribute DictionaryAttr::get(Identifier name) const {
180   Optional<NamedAttribute> attr = getNamed(name);
181   return attr ? attr->second : nullptr;
182 }
183 
184 /// Return the specified named attribute if present, None otherwise.
getNamed(StringRef name) const185 Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
186   ArrayRef<NamedAttribute> values = getValue();
187   const auto *it = llvm::lower_bound(values, name);
188   return it != values.end() && it->first == name ? *it
189                                                  : Optional<NamedAttribute>();
190 }
getNamed(Identifier name) const191 Optional<NamedAttribute> DictionaryAttr::getNamed(Identifier name) const {
192   for (auto elt : getValue())
193     if (elt.first == name)
194       return elt;
195   return llvm::None;
196 }
197 
begin() const198 DictionaryAttr::iterator DictionaryAttr::begin() const {
199   return getValue().begin();
200 }
end() const201 DictionaryAttr::iterator DictionaryAttr::end() const {
202   return getValue().end();
203 }
size() const204 size_t DictionaryAttr::size() const { return getValue().size(); }
205 
206 //===----------------------------------------------------------------------===//
207 // FloatAttr
208 //===----------------------------------------------------------------------===//
209 
get(Type type,double value)210 FloatAttr FloatAttr::get(Type type, double value) {
211   return Base::get(type.getContext(), type, value);
212 }
213 
getChecked(Type type,double value,Location loc)214 FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
215   return Base::getChecked(loc, type, value);
216 }
217 
get(Type type,const APFloat & value)218 FloatAttr FloatAttr::get(Type type, const APFloat &value) {
219   return Base::get(type.getContext(), type, value);
220 }
221 
getChecked(Type type,const APFloat & value,Location loc)222 FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
223   return Base::getChecked(loc, type, value);
224 }
225 
getValue() const226 APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
227 
getValueAsDouble() const228 double FloatAttr::getValueAsDouble() const {
229   return getValueAsDouble(getValue());
230 }
getValueAsDouble(APFloat value)231 double FloatAttr::getValueAsDouble(APFloat value) {
232   if (&value.getSemantics() != &APFloat::IEEEdouble()) {
233     bool losesInfo = false;
234     value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
235                   &losesInfo);
236   }
237   return value.convertToDouble();
238 }
239 
240 /// Verify construction invariants.
verifyFloatTypeInvariants(Location loc,Type type)241 static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) {
242   if (!type.isa<FloatType>())
243     return emitError(loc, "expected floating point type");
244   return success();
245 }
246 
verifyConstructionInvariants(Location loc,Type type,double value)247 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
248                                                       double value) {
249   return verifyFloatTypeInvariants(loc, type);
250 }
251 
verifyConstructionInvariants(Location loc,Type type,const APFloat & value)252 LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
253                                                       const APFloat &value) {
254   // Verify that the type is correct.
255   if (failed(verifyFloatTypeInvariants(loc, type)))
256     return failure();
257 
258   // Verify that the type semantics match that of the value.
259   if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
260     return emitError(
261         loc, "FloatAttr type doesn't match the type implied by its value");
262   }
263   return success();
264 }
265 
266 //===----------------------------------------------------------------------===//
267 // SymbolRefAttr
268 //===----------------------------------------------------------------------===//
269 
get(StringRef value,MLIRContext * ctx)270 FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
271   return Base::get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>();
272 }
273 
get(StringRef value,ArrayRef<FlatSymbolRefAttr> nestedReferences,MLIRContext * ctx)274 SymbolRefAttr SymbolRefAttr::get(StringRef value,
275                                  ArrayRef<FlatSymbolRefAttr> nestedReferences,
276                                  MLIRContext *ctx) {
277   return Base::get(ctx, value, nestedReferences);
278 }
279 
getRootReference() const280 StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; }
281 
getLeafReference() const282 StringRef SymbolRefAttr::getLeafReference() const {
283   ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
284   return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue();
285 }
286 
getNestedReferences() const287 ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
288   return getImpl()->getNestedRefs();
289 }
290 
291 //===----------------------------------------------------------------------===//
292 // IntegerAttr
293 //===----------------------------------------------------------------------===//
294 
get(Type type,const APInt & value)295 IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
296   if (type.isSignlessInteger(1))
297     return BoolAttr::get(value.getBoolValue(), type.getContext());
298   return Base::get(type.getContext(), type, value);
299 }
300 
get(Type type,int64_t value)301 IntegerAttr IntegerAttr::get(Type type, int64_t value) {
302   // This uses 64 bit APInts by default for index type.
303   if (type.isIndex())
304     return get(type, APInt(IndexType::kInternalStorageBitWidth, value));
305 
306   auto intType = type.cast<IntegerType>();
307   return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger()));
308 }
309 
getValue() const310 APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }
311 
getInt() const312 int64_t IntegerAttr::getInt() const {
313   assert((getImpl()->getType().isIndex() ||
314           getImpl()->getType().isSignlessInteger()) &&
315          "must be signless integer");
316   return getValue().getSExtValue();
317 }
318 
getSInt() const319 int64_t IntegerAttr::getSInt() const {
320   assert(getImpl()->getType().isSignedInteger() && "must be signed integer");
321   return getValue().getSExtValue();
322 }
323 
getUInt() const324 uint64_t IntegerAttr::getUInt() const {
325   assert(getImpl()->getType().isUnsignedInteger() &&
326          "must be unsigned integer");
327   return getValue().getZExtValue();
328 }
329 
verifyIntegerTypeInvariants(Location loc,Type type)330 static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) {
331   if (type.isa<IntegerType, IndexType>())
332     return success();
333   return emitError(loc, "expected integer or index type");
334 }
335 
verifyConstructionInvariants(Location loc,Type type,int64_t value)336 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
337                                                         int64_t value) {
338   return verifyIntegerTypeInvariants(loc, type);
339 }
340 
verifyConstructionInvariants(Location loc,Type type,const APInt & value)341 LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
342                                                         const APInt &value) {
343   if (failed(verifyIntegerTypeInvariants(loc, type)))
344     return failure();
345   if (auto integerType = type.dyn_cast<IntegerType>())
346     if (integerType.getWidth() != value.getBitWidth())
347       return emitError(loc, "integer type bit width (")
348              << integerType.getWidth() << ") doesn't match value bit width ("
349              << value.getBitWidth() << ")";
350   return success();
351 }
352 
353 //===----------------------------------------------------------------------===//
354 // BoolAttr
355 
getValue() const356 bool BoolAttr::getValue() const {
357   auto *storage = reinterpret_cast<IntegerAttributeStorage *>(impl);
358   return storage->getValue().getBoolValue();
359 }
360 
classof(Attribute attr)361 bool BoolAttr::classof(Attribute attr) {
362   IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
363   return intAttr && intAttr.getType().isSignlessInteger(1);
364 }
365 
366 //===----------------------------------------------------------------------===//
367 // IntegerSetAttr
368 //===----------------------------------------------------------------------===//
369 
get(IntegerSet value)370 IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
371   return Base::get(value.getConstraint(0).getContext(), value);
372 }
373 
getValue() const374 IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
375 
376 //===----------------------------------------------------------------------===//
377 // OpaqueAttr
378 //===----------------------------------------------------------------------===//
379 
get(Identifier dialect,StringRef attrData,Type type,MLIRContext * context)380 OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
381                            MLIRContext *context) {
382   return Base::get(context, dialect, attrData, type);
383 }
384 
getChecked(Identifier dialect,StringRef attrData,Type type,Location location)385 OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
386                                   Type type, Location location) {
387   return Base::getChecked(location, dialect, attrData, type);
388 }
389 
390 /// Returns the dialect namespace of the opaque attribute.
getDialectNamespace() const391 Identifier OpaqueAttr::getDialectNamespace() const {
392   return getImpl()->dialectNamespace;
393 }
394 
395 /// Returns the raw attribute data of the opaque attribute.
getAttrData() const396 StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
397 
398 /// Verify the construction of an opaque attribute.
verifyConstructionInvariants(Location loc,Identifier dialect,StringRef attrData,Type type)399 LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc,
400                                                        Identifier dialect,
401                                                        StringRef attrData,
402                                                        Type type) {
403   if (!Dialect::isValidNamespace(dialect.strref()))
404     return emitError(loc, "invalid dialect namespace '") << dialect << "'";
405   return success();
406 }
407 
408 //===----------------------------------------------------------------------===//
409 // StringAttr
410 //===----------------------------------------------------------------------===//
411 
get(StringRef bytes,MLIRContext * context)412 StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
413   return get(bytes, NoneType::get(context));
414 }
415 
416 /// Get an instance of a StringAttr with the given string and Type.
get(StringRef bytes,Type type)417 StringAttr StringAttr::get(StringRef bytes, Type type) {
418   return Base::get(type.getContext(), bytes, type);
419 }
420 
getValue() const421 StringRef StringAttr::getValue() const { return getImpl()->value; }
422 
423 //===----------------------------------------------------------------------===//
424 // TypeAttr
425 //===----------------------------------------------------------------------===//
426 
get(Type value)427 TypeAttr TypeAttr::get(Type value) {
428   return Base::get(value.getContext(), value);
429 }
430 
getValue() const431 Type TypeAttr::getValue() const { return getImpl()->value; }
432 
433 //===----------------------------------------------------------------------===//
434 // ElementsAttr
435 //===----------------------------------------------------------------------===//
436 
getType() const437 ShapedType ElementsAttr::getType() const {
438   return Attribute::getType().cast<ShapedType>();
439 }
440 
441 /// Returns the number of elements held by this attribute.
getNumElements() const442 int64_t ElementsAttr::getNumElements() const {
443   return getType().getNumElements();
444 }
445 
446 /// Return the value at the given index. If index does not refer to a valid
447 /// element, then a null attribute is returned.
getValue(ArrayRef<uint64_t> index) const448 Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
449   if (auto denseAttr = dyn_cast<DenseElementsAttr>())
450     return denseAttr.getValue(index);
451   if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>())
452     return opaqueAttr.getValue(index);
453   return cast<SparseElementsAttr>().getValue(index);
454 }
455 
456 /// Return if the given 'index' refers to a valid element in this attribute.
isValidIndex(ArrayRef<uint64_t> index) const457 bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
458   auto type = getType();
459 
460   // Verify that the rank of the indices matches the held type.
461   auto rank = type.getRank();
462   if (rank != static_cast<int64_t>(index.size()))
463     return false;
464 
465   // Verify that all of the indices are within the shape dimensions.
466   auto shape = type.getShape();
467   return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
468     return static_cast<int64_t>(index[i]) < shape[i];
469   });
470 }
471 
472 ElementsAttr
mapValues(Type newElementType,function_ref<APInt (const APInt &)> mapping) const473 ElementsAttr::mapValues(Type newElementType,
474                         function_ref<APInt(const APInt &)> mapping) const {
475   if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
476     return intOrFpAttr.mapValues(newElementType, mapping);
477   llvm_unreachable("unsupported ElementsAttr subtype");
478 }
479 
480 ElementsAttr
mapValues(Type newElementType,function_ref<APInt (const APFloat &)> mapping) const481 ElementsAttr::mapValues(Type newElementType,
482                         function_ref<APInt(const APFloat &)> mapping) const {
483   if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
484     return intOrFpAttr.mapValues(newElementType, mapping);
485   llvm_unreachable("unsupported ElementsAttr subtype");
486 }
487 
488 /// Method for support type inquiry through isa, cast and dyn_cast.
classof(Attribute attr)489 bool ElementsAttr::classof(Attribute attr) {
490   return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr,
491                   OpaqueElementsAttr, SparseElementsAttr>();
492 }
493 
494 /// Returns the 1 dimensional flattened row-major index from the given
495 /// multi-dimensional index.
getFlattenedIndex(ArrayRef<uint64_t> index) const496 uint64_t ElementsAttr::getFlattenedIndex(ArrayRef<uint64_t> index) const {
497   assert(isValidIndex(index) && "expected valid multi-dimensional index");
498   auto type = getType();
499 
500   // Reduce the provided multidimensional index into a flattended 1D row-major
501   // index.
502   auto rank = type.getRank();
503   auto shape = type.getShape();
504   uint64_t valueIndex = 0;
505   uint64_t dimMultiplier = 1;
506   for (int i = rank - 1; i >= 0; --i) {
507     valueIndex += index[i] * dimMultiplier;
508     dimMultiplier *= shape[i];
509   }
510   return valueIndex;
511 }
512 
513 //===----------------------------------------------------------------------===//
514 // DenseElementsAttr Utilities
515 //===----------------------------------------------------------------------===//
516 
517 /// Get the bitwidth of a dense element type within the buffer.
518 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
getDenseElementStorageWidth(size_t origWidth)519 static size_t getDenseElementStorageWidth(size_t origWidth) {
520   return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
521 }
getDenseElementStorageWidth(Type elementType)522 static size_t getDenseElementStorageWidth(Type elementType) {
523   return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
524 }
525 
/// Sets or clears the bit at position `bitPos` of the buffer `rawData`. Bits
/// are numbered from the least significant bit of each byte.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  if (value)
    byte |= mask;
  else
    byte &= ~mask;
}
533 
/// Reads the bit at position `bitPos` of the buffer `rawData`. Bits are
/// numbered from the least significant bit of each byte.
static bool getBit(const char *rawData, size_t bitPos) {
  const char byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  return (byte & mask) != 0;
}
538 
539 /// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for
540 /// BE format.
copyAPIntToArrayForBEmachine(APInt value,size_t numBytes,char * result)541 static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
542                                          char *result) {
543   assert(llvm::support::endian::system_endianness() == // NOLINT
544          llvm::support::endianness::big);              // NOLINT
545   assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);
546 
547   // Copy the words filled with data.
548   // For example, when `value` has 2 words, the first word is filled with data.
549   // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
550   size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
551   std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
552               numFilledWords, result);
553   // Convert last word of APInt to LE format and store it in char
554   // array(`valueLE`).
555   // ex. last word of `value` (BE): |------ij|  ==> `valueLE` (LE): |ji------|
556   size_t lastWordPos = numFilledWords;
557   SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
558   DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
559       reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
560       valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
561   // Extract actual APInt data from `valueLE`, convert endianness to BE format,
562   // and store it in `result`.
563   // ex. `valueLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|ij|
564   DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
565       valueLE.begin(), result + lastWordPos,
566       (numBytes - lastWordPos) * CHAR_BIT, 1);
567 }
568 
569 /// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE
570 /// format.
copyArrayToAPIntForBEmachine(const char * inArray,size_t numBytes,APInt & result)571 static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
572                                          APInt &result) {
573   assert(llvm::support::endian::system_endianness() == // NOLINT
574          llvm::support::endianness::big);              // NOLINT
575   assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);
576 
577   // Copy the data that fills the word of `result` from `inArray`.
578   // For example, when `result` has 2 words, the first word will be filled with
579   // data. So, the first 8 bytes are copied from `inArray` here.
580   // `inArray` (10 bytes, BE): |abcdefgh|ij|
581   //                     ==> `result` (2 words, BE): |abcdefgh|--------|
582   size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
583   std::copy_n(
584       inArray, numFilledWords,
585       const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));
586 
587   // Convert array data which will be last word of `result` to LE format, and
588   // store it in char array(`inArrayLE`).
589   // ex. `inArray` (last two bytes, BE): |ij|  ==> `inArrayLE` (LE): |ji------|
590   size_t lastWordPos = numFilledWords;
591   SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
592   DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
593       inArray + lastWordPos, inArrayLE.begin(),
594       (numBytes - lastWordPos) * CHAR_BIT, 1);
595 
596   // Convert `inArrayLE` to BE format, and store it in last word of `result`.
597   // ex. `inArrayLE` (LE): |ji------|  ==> `result` (BE): |abcdefgh|------ij|
598   DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
599       inArrayLE.begin(),
600       const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
601           lastWordPos,
602       APInt::APINT_BITS_PER_WORD, 1);
603 }
604 
605 /// Writes value to the bit position `bitPos` in array `rawData`.
writeBits(char * rawData,size_t bitPos,APInt value)606 static void writeBits(char *rawData, size_t bitPos, APInt value) {
607   size_t bitWidth = value.getBitWidth();
608 
609   // If the bitwidth is 1 we just toggle the specific bit.
610   if (bitWidth == 1)
611     return setBit(rawData, bitPos, value.isOneValue());
612 
613   // Otherwise, the bit position is guaranteed to be byte aligned.
614   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
615   if (llvm::support::endian::system_endianness() ==
616       llvm::support::endianness::big) {
617     // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
618     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
619     // work correctly in BE format.
620     // ex. `value` (2 words including 10 bytes)
621     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------|
622     copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT),
623                                  rawData + (bitPos / CHAR_BIT));
624   } else {
625     std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
626                 llvm::divideCeil(bitWidth, CHAR_BIT),
627                 rawData + (bitPos / CHAR_BIT));
628   }
629 }
630 
631 /// Reads the next `bitWidth` bits from the bit position `bitPos` in array
632 /// `rawData`.
readBits(const char * rawData,size_t bitPos,size_t bitWidth)633 static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
634   // Handle a boolean bit position.
635   if (bitWidth == 1)
636     return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
637 
638   // Otherwise, the bit position must be 8-bit aligned.
639   assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
640   APInt result(bitWidth, 0);
641   if (llvm::support::endian::system_endianness() ==
642       llvm::support::endianness::big) {
643     // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
644     // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
645     // work correctly in BE format.
646     // ex. `result` (2 words including 10 bytes)
647     // ==> BE: |abcdefgh|------ij|,  LE: |hgfedcba|ji------| This function
648     copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT),
649                                  llvm::divideCeil(bitWidth, CHAR_BIT), result);
650   } else {
651     std::copy_n(rawData + (bitPos / CHAR_BIT),
652                 llvm::divideCeil(bitWidth, CHAR_BIT),
653                 const_cast<char *>(
654                     reinterpret_cast<const char *>(result.getRawData())));
655   }
656   return result;
657 }
658 
659 /// Returns true if 'values' corresponds to a splat, i.e. one element, or has
660 /// the same element count as 'type'.
661 template <typename Values>
hasSameElementsOrSplat(ShapedType type,const Values & values)662 static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
663   return (values.size() == 1) ||
664          (type.getNumElements() == static_cast<int64_t>(values.size()));
665 }
666 
667 //===----------------------------------------------------------------------===//
668 // DenseElementsAttr Iterators
669 //===----------------------------------------------------------------------===//
670 
671 //===----------------------------------------------------------------------===//
672 // AttributeElementIterator
673 
AttributeElementIterator(DenseElementsAttr attr,size_t index)674 DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
675     DenseElementsAttr attr, size_t index)
676     : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
677                                       Attribute, Attribute, Attribute>(
678           attr.getAsOpaquePointer(), index) {}
679 
operator *() const680 Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
681   auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
682   Type eltTy = owner.getType().getElementType();
683   if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
684     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
685   if (eltTy.isa<IndexType>())
686     return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
687   if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
688     IntElementIterator intIt(owner, index);
689     FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
690     return FloatAttr::get(eltTy, *floatIt);
691   }
692   if (owner.isa<DenseStringElementsAttr>()) {
693     ArrayRef<StringRef> vals = owner.getRawStringData();
694     return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
695   }
696   llvm_unreachable("unexpected element type");
697 }
698 
699 //===----------------------------------------------------------------------===//
700 // BoolElementIterator
701 
BoolElementIterator(DenseElementsAttr attr,size_t dataIndex)702 DenseElementsAttr::BoolElementIterator::BoolElementIterator(
703     DenseElementsAttr attr, size_t dataIndex)
704     : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
705           attr.getRawData().data(), attr.isSplat(), dataIndex) {}
706 
operator *() const707 bool DenseElementsAttr::BoolElementIterator::operator*() const {
708   return getBit(getData(), getDataIndex());
709 }
710 
711 //===----------------------------------------------------------------------===//
712 // IntElementIterator
713 
IntElementIterator(DenseElementsAttr attr,size_t dataIndex)714 DenseElementsAttr::IntElementIterator::IntElementIterator(
715     DenseElementsAttr attr, size_t dataIndex)
716     : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
717           attr.getRawData().data(), attr.isSplat(), dataIndex),
718       bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {}
719 
operator *() const720 APInt DenseElementsAttr::IntElementIterator::operator*() const {
721   return readBits(getData(),
722                   getDataIndex() * getDenseElementStorageWidth(bitWidth),
723                   bitWidth);
724 }
725 
726 //===----------------------------------------------------------------------===//
727 // ComplexIntElementIterator
728 
ComplexIntElementIterator(DenseElementsAttr attr,size_t dataIndex)729 DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
730     DenseElementsAttr attr, size_t dataIndex)
731     : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
732                                       std::complex<APInt>, std::complex<APInt>,
733                                       std::complex<APInt>>(
734           attr.getRawData().data(), attr.isSplat(), dataIndex) {
735   auto complexType = attr.getType().getElementType().cast<ComplexType>();
736   bitWidth = getDenseElementBitWidth(complexType.getElementType());
737 }
738 
739 std::complex<APInt>
operator *() const740 DenseElementsAttr::ComplexIntElementIterator::operator*() const {
741   size_t storageWidth = getDenseElementStorageWidth(bitWidth);
742   size_t offset = getDataIndex() * storageWidth * 2;
743   return {readBits(getData(), offset, bitWidth),
744           readBits(getData(), offset + storageWidth, bitWidth)};
745 }
746 
747 //===----------------------------------------------------------------------===//
748 // FloatElementIterator
749 
FloatElementIterator(const llvm::fltSemantics & smt,IntElementIterator it)750 DenseElementsAttr::FloatElementIterator::FloatElementIterator(
751     const llvm::fltSemantics &smt, IntElementIterator it)
752     : llvm::mapped_iterator<IntElementIterator,
753                             std::function<APFloat(const APInt &)>>(
754           it, [&](const APInt &val) { return APFloat(smt, val); }) {}
755 
756 //===----------------------------------------------------------------------===//
757 // ComplexFloatElementIterator
758 
ComplexFloatElementIterator(const llvm::fltSemantics & smt,ComplexIntElementIterator it)759 DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
760     const llvm::fltSemantics &smt, ComplexIntElementIterator it)
761     : llvm::mapped_iterator<
762           ComplexIntElementIterator,
763           std::function<std::complex<APFloat>(const std::complex<APInt> &)>>(
764           it, [&](const std::complex<APInt> &val) -> std::complex<APFloat> {
765             return {APFloat(smt, val.real()), APFloat(smt, val.imag())};
766           }) {}
767 
768 //===----------------------------------------------------------------------===//
769 // DenseElementsAttr
770 //===----------------------------------------------------------------------===//
771 
772 /// Method for support type inquiry through isa, cast and dyn_cast.
classof(Attribute attr)773 bool DenseElementsAttr::classof(Attribute attr) {
774   return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
775 }
776 
get(ShapedType type,ArrayRef<Attribute> values)777 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
778                                          ArrayRef<Attribute> values) {
779   assert(hasSameElementsOrSplat(type, values));
780 
781   // If the element type is not based on int/float/index, assume it is a string
782   // type.
783   auto eltType = type.getElementType();
784   if (!type.getElementType().isIntOrIndexOrFloat()) {
785     SmallVector<StringRef, 8> stringValues;
786     stringValues.reserve(values.size());
787     for (Attribute attr : values) {
788       assert(attr.isa<StringAttr>() &&
789              "expected string value for non integer/index/float element");
790       stringValues.push_back(attr.cast<StringAttr>().getValue());
791     }
792     return get(type, stringValues);
793   }
794 
795   // Otherwise, get the raw storage width to use for the allocation.
796   size_t bitWidth = getDenseElementBitWidth(eltType);
797   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
798 
799   // Compress the attribute values into a character buffer.
800   SmallVector<char, 8> data(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
801                             values.size());
802   APInt intVal;
803   for (unsigned i = 0, e = values.size(); i < e; ++i) {
804     assert(eltType == values[i].getType() &&
805            "expected attribute value to have element type");
806     if (eltType.isa<FloatType>())
807       intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
808     else if (eltType.isa<IntegerType>())
809       intVal = values[i].cast<IntegerAttr>().getValue();
810     else
811       llvm_unreachable("unexpected element type");
812 
813     assert(intVal.getBitWidth() == bitWidth &&
814            "expected value to have same bitwidth as element type");
815     writeBits(data.data(), i * storageBitWidth, intVal);
816   }
817   return DenseIntOrFPElementsAttr::getRaw(type, data,
818                                           /*isSplat=*/(values.size() == 1));
819 }
820 
get(ShapedType type,ArrayRef<bool> values)821 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
822                                          ArrayRef<bool> values) {
823   assert(hasSameElementsOrSplat(type, values));
824   assert(type.getElementType().isInteger(1));
825 
826   std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
827   for (int i = 0, e = values.size(); i != e; ++i)
828     setBit(buff.data(), i, values[i]);
829   return DenseIntOrFPElementsAttr::getRaw(type, buff,
830                                           /*isSplat=*/(values.size() == 1));
831 }
832 
get(ShapedType type,ArrayRef<StringRef> values)833 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
834                                          ArrayRef<StringRef> values) {
835   assert(!type.getElementType().isIntOrFloat());
836   return DenseStringElementsAttr::get(type, values);
837 }
838 
839 /// Constructs a dense integer elements attribute from an array of APInt
840 /// values. Each APInt value is expected to have the same bitwidth as the
841 /// element type of 'type'.
get(ShapedType type,ArrayRef<APInt> values)842 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
843                                          ArrayRef<APInt> values) {
844   assert(type.getElementType().isIntOrIndex());
845   assert(hasSameElementsOrSplat(type, values));
846   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
847   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
848                                           /*isSplat=*/(values.size() == 1));
849 }
get(ShapedType type,ArrayRef<std::complex<APInt>> values)850 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
851                                          ArrayRef<std::complex<APInt>> values) {
852   ComplexType complex = type.getElementType().cast<ComplexType>();
853   assert(complex.getElementType().isa<IntegerType>());
854   assert(hasSameElementsOrSplat(type, values));
855   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
856   ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
857                           values.size() * 2);
858   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals,
859                                           /*isSplat=*/(values.size() == 1));
860 }
861 
862 // Constructs a dense float elements attribute from an array of APFloat
863 // values. Each APFloat value is expected to have the same bitwidth as the
864 // element type of 'type'.
get(ShapedType type,ArrayRef<APFloat> values)865 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
866                                          ArrayRef<APFloat> values) {
867   assert(type.getElementType().isa<FloatType>());
868   assert(hasSameElementsOrSplat(type, values));
869   size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
870   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values,
871                                           /*isSplat=*/(values.size() == 1));
872 }
873 DenseElementsAttr
get(ShapedType type,ArrayRef<std::complex<APFloat>> values)874 DenseElementsAttr::get(ShapedType type,
875                        ArrayRef<std::complex<APFloat>> values) {
876   ComplexType complex = type.getElementType().cast<ComplexType>();
877   assert(complex.getElementType().isa<FloatType>());
878   assert(hasSameElementsOrSplat(type, values));
879   ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
880                            values.size() * 2);
881   size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
882   return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals,
883                                           /*isSplat=*/(values.size() == 1));
884 }
885 
886 /// Construct a dense elements attribute from a raw buffer representing the
887 /// data for this attribute. Users should generally not use this methods as
888 /// the expected buffer format may not be a form the user expects.
getFromRawBuffer(ShapedType type,ArrayRef<char> rawBuffer,bool isSplatBuffer)889 DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
890                                                       ArrayRef<char> rawBuffer,
891                                                       bool isSplatBuffer) {
892   return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer, isSplatBuffer);
893 }
894 
895 /// Returns true if the given buffer is a valid raw buffer for the given type.
isValidRawBuffer(ShapedType type,ArrayRef<char> rawBuffer,bool & detectedSplat)896 bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
897                                          ArrayRef<char> rawBuffer,
898                                          bool &detectedSplat) {
899   size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
900   size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
901 
902   // Storage width of 1 is special as it is packed by the bit.
903   if (storageWidth == 1) {
904     // Check for a splat, or a buffer equal to the number of elements.
905     if ((detectedSplat = rawBuffer.size() == 1))
906       return true;
907     return rawBufferWidth == llvm::alignTo<8>(type.getNumElements());
908   }
909   // All other types are 8-bit aligned.
910   if ((detectedSplat = rawBufferWidth == storageWidth))
911     return true;
912   return rawBufferWidth == (storageWidth * type.getNumElements());
913 }
914 
915 /// Check the information for a C++ data type, check if this type is valid for
916 /// the current attribute. This method is used to verify specific type
917 /// invariants that the templatized 'getValues' method cannot.
isValidIntOrFloat(Type type,int64_t dataEltSize,bool isInt,bool isSigned)918 static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
919                               bool isSigned) {
920   // Make sure that the data element size is the same as the type element width.
921   if (getDenseElementBitWidth(type) !=
922       static_cast<size_t>(dataEltSize * CHAR_BIT))
923     return false;
924 
925   // Check that the element type is either float or integer or index.
926   if (!isInt)
927     return type.isa<FloatType>();
928   if (type.isIndex())
929     return true;
930 
931   auto intType = type.dyn_cast<IntegerType>();
932   if (!intType)
933     return false;
934 
935   // Make sure signedness semantics is consistent.
936   if (intType.isSignless())
937     return true;
938   return intType.isSigned() ? isSigned : !isSigned;
939 }
940 
941 /// Defaults down the subclass implementation.
getRawComplex(ShapedType type,ArrayRef<char> data,int64_t dataEltSize,bool isInt,bool isSigned)942 DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
943                                                    ArrayRef<char> data,
944                                                    int64_t dataEltSize,
945                                                    bool isInt, bool isSigned) {
946   return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
947                                                  isSigned);
948 }
getRawIntOrFloat(ShapedType type,ArrayRef<char> data,int64_t dataEltSize,bool isInt,bool isSigned)949 DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
950                                                       ArrayRef<char> data,
951                                                       int64_t dataEltSize,
952                                                       bool isInt,
953                                                       bool isSigned) {
954   return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
955                                                     isInt, isSigned);
956 }
957 
958 /// A method used to verify specific type invariants that the templatized 'get'
959 /// method cannot.
isValidIntOrFloat(int64_t dataEltSize,bool isInt,bool isSigned) const960 bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
961                                           bool isSigned) const {
962   return ::isValidIntOrFloat(getType().getElementType(), dataEltSize, isInt,
963                              isSigned);
964 }
965 
966 /// Check the information for a C++ data type, check if this type is valid for
967 /// the current attribute.
isValidComplex(int64_t dataEltSize,bool isInt,bool isSigned) const968 bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
969                                        bool isSigned) const {
970   return ::isValidIntOrFloat(
971       getType().getElementType().cast<ComplexType>().getElementType(),
972       dataEltSize / 2, isInt, isSigned);
973 }
974 
975 /// Returns true if this attribute corresponds to a splat, i.e. if all element
976 /// values are the same.
isSplat() const977 bool DenseElementsAttr::isSplat() const {
978   return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
979 }
980 
981 /// Return the held element values as a range of Attributes.
getAttributeValues() const982 auto DenseElementsAttr::getAttributeValues() const
983     -> llvm::iterator_range<AttributeElementIterator> {
984   return {attr_value_begin(), attr_value_end()};
985 }
attr_value_begin() const986 auto DenseElementsAttr::attr_value_begin() const -> AttributeElementIterator {
987   return AttributeElementIterator(*this, 0);
988 }
attr_value_end() const989 auto DenseElementsAttr::attr_value_end() const -> AttributeElementIterator {
990   return AttributeElementIterator(*this, getNumElements());
991 }
992 
993 /// Return the held element values as a range of bool. The element type of
994 /// this attribute must be of integer type of bitwidth 1.
getBoolValues() const995 auto DenseElementsAttr::getBoolValues() const
996     -> llvm::iterator_range<BoolElementIterator> {
997   auto eltType = getType().getElementType().dyn_cast<IntegerType>();
998   assert(eltType && eltType.getWidth() == 1 && "expected i1 integer type");
999   (void)eltType;
1000   return {BoolElementIterator(*this, 0),
1001           BoolElementIterator(*this, getNumElements())};
1002 }
1003 
1004 /// Return the held element values as a range of APInts. The element type of
1005 /// this attribute must be of integer type.
getIntValues() const1006 auto DenseElementsAttr::getIntValues() const
1007     -> llvm::iterator_range<IntElementIterator> {
1008   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
1009   return {raw_int_begin(), raw_int_end()};
1010 }
int_value_begin() const1011 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
1012   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
1013   return raw_int_begin();
1014 }
int_value_end() const1015 auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
1016   assert(getType().getElementType().isIntOrIndex() && "expected integral type");
1017   return raw_int_end();
1018 }
getComplexIntValues() const1019 auto DenseElementsAttr::getComplexIntValues() const
1020     -> llvm::iterator_range<ComplexIntElementIterator> {
1021   Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
1022   (void)eltTy;
1023   assert(eltTy.isa<IntegerType>() && "expected complex integral type");
1024   return {ComplexIntElementIterator(*this, 0),
1025           ComplexIntElementIterator(*this, getNumElements())};
1026 }
1027 
1028 /// Return the held element values as a range of APFloat. The element type of
1029 /// this attribute must be of float type.
getFloatValues() const1030 auto DenseElementsAttr::getFloatValues() const
1031     -> llvm::iterator_range<FloatElementIterator> {
1032   auto elementType = getType().getElementType().cast<FloatType>();
1033   const auto &elementSemantics = elementType.getFloatSemantics();
1034   return {FloatElementIterator(elementSemantics, raw_int_begin()),
1035           FloatElementIterator(elementSemantics, raw_int_end())};
1036 }
float_value_begin() const1037 auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
1038   return getFloatValues().begin();
1039 }
float_value_end() const1040 auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
1041   return getFloatValues().end();
1042 }
getComplexFloatValues() const1043 auto DenseElementsAttr::getComplexFloatValues() const
1044     -> llvm::iterator_range<ComplexFloatElementIterator> {
1045   Type eltTy = getType().getElementType().cast<ComplexType>().getElementType();
1046   assert(eltTy.isa<FloatType>() && "expected complex float type");
1047   const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
1048   return {{semantics, {*this, 0}},
1049           {semantics, {*this, static_cast<size_t>(getNumElements())}}};
1050 }
1051 
1052 /// Return the raw storage data held by this attribute.
getRawData() const1053 ArrayRef<char> DenseElementsAttr::getRawData() const {
1054   return static_cast<DenseIntOrFPElementsAttributeStorage *>(impl)->data;
1055 }
1056 
getRawStringData() const1057 ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
1058   return static_cast<DenseStringElementsAttributeStorage *>(impl)->data;
1059 }
1060 
1061 /// Return a new DenseElementsAttr that has the same data as the current
1062 /// attribute, but has been reshaped to 'newType'. The new type must have the
1063 /// same total number of elements as well as element type.
reshape(ShapedType newType)1064 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
1065   ShapedType curType = getType();
1066   if (curType == newType)
1067     return *this;
1068 
1069   (void)curType;
1070   assert(newType.getElementType() == curType.getElementType() &&
1071          "expected the same element type");
1072   assert(newType.getNumElements() == curType.getNumElements() &&
1073          "expected the same number of elements");
1074   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
1075 }
1076 
1077 DenseElementsAttr
mapValues(Type newElementType,function_ref<APInt (const APInt &)> mapping) const1078 DenseElementsAttr::mapValues(Type newElementType,
1079                              function_ref<APInt(const APInt &)> mapping) const {
1080   return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
1081 }
1082 
mapValues(Type newElementType,function_ref<APInt (const APFloat &)> mapping) const1083 DenseElementsAttr DenseElementsAttr::mapValues(
1084     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1085   return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
1086 }
1087 
1088 //===----------------------------------------------------------------------===//
1089 // DenseStringElementsAttr
1090 //===----------------------------------------------------------------------===//
1091 
1092 DenseStringElementsAttr
get(ShapedType type,ArrayRef<StringRef> values)1093 DenseStringElementsAttr::get(ShapedType type, ArrayRef<StringRef> values) {
1094   return Base::get(type.getContext(), type, values, (values.size() == 1));
1095 }
1096 
1097 //===----------------------------------------------------------------------===//
1098 // DenseIntOrFPElementsAttr
1099 //===----------------------------------------------------------------------===//
1100 
1101 /// Utility method to write a range of APInt values to a buffer.
1102 template <typename APRangeT>
writeAPIntsToBuffer(size_t storageWidth,std::vector<char> & data,APRangeT && values)1103 static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
1104                                 APRangeT &&values) {
1105   data.resize(llvm::divideCeil(storageWidth, CHAR_BIT) * llvm::size(values));
1106   size_t offset = 0;
1107   for (auto it = values.begin(), e = values.end(); it != e;
1108        ++it, offset += storageWidth) {
1109     assert((*it).getBitWidth() <= storageWidth);
1110     writeBits(data.data(), offset, *it);
1111   }
1112 }
1113 
1114 /// Constructs a dense elements attribute from an array of raw APFloat values.
1115 /// Each APFloat value is expected to have the same bitwidth as the element
1116 /// type of 'type'. 'type' must be a vector or tensor with static shape.
getRaw(ShapedType type,size_t storageWidth,ArrayRef<APFloat> values,bool isSplat)1117 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1118                                                    size_t storageWidth,
1119                                                    ArrayRef<APFloat> values,
1120                                                    bool isSplat) {
1121   std::vector<char> data;
1122   auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1123   writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
1124   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1125 }
1126 
1127 /// Constructs a dense elements attribute from an array of raw APInt values.
1128 /// Each APInt value is expected to have the same bitwidth as the element type
1129 /// of 'type'.
getRaw(ShapedType type,size_t storageWidth,ArrayRef<APInt> values,bool isSplat)1130 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1131                                                    size_t storageWidth,
1132                                                    ArrayRef<APInt> values,
1133                                                    bool isSplat) {
1134   std::vector<char> data;
1135   writeAPIntsToBuffer(storageWidth, data, values);
1136   return DenseIntOrFPElementsAttr::getRaw(type, data, isSplat);
1137 }
1138 
getRaw(ShapedType type,ArrayRef<char> data,bool isSplat)1139 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1140                                                    ArrayRef<char> data,
1141                                                    bool isSplat) {
1142   assert((type.isa<RankedTensorType, VectorType>()) &&
1143          "type must be ranked tensor or vector");
1144   assert(type.hasStaticShape() && "type must have static shape");
1145   return Base::get(type.getContext(), type, data, isSplat);
1146 }
1147 
1148 /// Overload of the raw 'get' method that asserts that the given type is of
1149 /// complex type. This method is used to verify type invariants that the
1150 /// templatized 'get' method cannot.
getRawComplex(ShapedType type,ArrayRef<char> data,int64_t dataEltSize,bool isInt,bool isSigned)1151 DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1152                                                           ArrayRef<char> data,
1153                                                           int64_t dataEltSize,
1154                                                           bool isInt,
1155                                                           bool isSigned) {
1156   assert(::isValidIntOrFloat(
1157       type.getElementType().cast<ComplexType>().getElementType(),
1158       dataEltSize / 2, isInt, isSigned));
1159 
1160   int64_t numElements = data.size() / dataEltSize;
1161   assert(numElements == 1 || numElements == type.getNumElements());
1162   return getRaw(type, data, /*isSplat=*/numElements == 1);
1163 }
1164 
1165 /// Overload of the 'getRaw' method that asserts that the given type is of
1166 /// integer type. This method is used to verify type invariants that the
1167 /// templatized 'get' method cannot.
1168 DenseElementsAttr
getRawIntOrFloat(ShapedType type,ArrayRef<char> data,int64_t dataEltSize,bool isInt,bool isSigned)1169 DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1170                                            int64_t dataEltSize, bool isInt,
1171                                            bool isSigned) {
1172   assert(
1173       ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned));
1174 
1175   int64_t numElements = data.size() / dataEltSize;
1176   assert(numElements == 1 || numElements == type.getNumElements());
1177   return getRaw(type, data, /*isSplat=*/numElements == 1);
1178 }
1179 
convertEndianOfCharForBEmachine(const char * inRawData,char * outRawData,size_t elementBitWidth,size_t numElements)1180 void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
1181     const char *inRawData, char *outRawData, size_t elementBitWidth,
1182     size_t numElements) {
1183   using llvm::support::ulittle16_t;
1184   using llvm::support::ulittle32_t;
1185   using llvm::support::ulittle64_t;
1186 
1187   assert(llvm::support::endian::system_endianness() == // NOLINT
1188          llvm::support::endianness::big);              // NOLINT
1189   // NOLINT to avoid warning message about replacing by static_assert()
1190 
1191   // Following std::copy_n always converts endianness on BE machine.
1192   switch (elementBitWidth) {
1193   case 16: {
1194     const ulittle16_t *inRawDataPos =
1195         reinterpret_cast<const ulittle16_t *>(inRawData);
1196     uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData);
1197     std::copy_n(inRawDataPos, numElements, outDataPos);
1198     break;
1199   }
1200   case 32: {
1201     const ulittle32_t *inRawDataPos =
1202         reinterpret_cast<const ulittle32_t *>(inRawData);
1203     uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData);
1204     std::copy_n(inRawDataPos, numElements, outDataPos);
1205     break;
1206   }
1207   case 64: {
1208     const ulittle64_t *inRawDataPos =
1209         reinterpret_cast<const ulittle64_t *>(inRawData);
1210     uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData);
1211     std::copy_n(inRawDataPos, numElements, outDataPos);
1212     break;
1213   }
1214   default: {
1215     size_t nBytes = elementBitWidth / CHAR_BIT;
1216     for (size_t i = 0; i < nBytes; i++)
1217       std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i);
1218     break;
1219   }
1220   }
1221 }
1222 
convertEndianOfArrayRefForBEmachine(ArrayRef<char> inRawData,MutableArrayRef<char> outRawData,ShapedType type)1223 void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1224     ArrayRef<char> inRawData, MutableArrayRef<char> outRawData,
1225     ShapedType type) {
1226   size_t numElements = type.getNumElements();
1227   Type elementType = type.getElementType();
1228   if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1229     elementType = complexTy.getElementType();
1230     numElements = numElements * 2;
1231   }
1232   size_t elementBitWidth = getDenseElementStorageWidth(elementType);
1233   assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&
1234          inRawData.size() <= outRawData.size());
1235   convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(),
1236                                   elementBitWidth, numElements);
1237 }
1238 
1239 //===----------------------------------------------------------------------===//
1240 // DenseFPElementsAttr
1241 //===----------------------------------------------------------------------===//
1242 
1243 template <typename Fn, typename Attr>
mappingHelper(Fn mapping,Attr & attr,ShapedType inType,Type newElementType,llvm::SmallVectorImpl<char> & data)1244 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1245                                 Type newElementType,
1246                                 llvm::SmallVectorImpl<char> &data) {
1247   size_t bitWidth = getDenseElementBitWidth(newElementType);
1248   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1249 
1250   ShapedType newArrayType;
1251   if (inType.isa<RankedTensorType>())
1252     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1253   else if (inType.isa<UnrankedTensorType>())
1254     newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1255   else if (inType.isa<VectorType>())
1256     newArrayType = VectorType::get(inType.getShape(), newElementType);
1257   else
1258     assert(newArrayType && "Unhandled tensor type");
1259 
1260   size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1261   data.resize(llvm::divideCeil(storageBitWidth, CHAR_BIT) * numRawElements);
1262 
1263   // Functor used to process a single element value of the attribute.
1264   auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1265     auto newInt = mapping(value);
1266     assert(newInt.getBitWidth() == bitWidth);
1267     writeBits(data.data(), index * storageBitWidth, newInt);
1268   };
1269 
1270   // Check for the splat case.
1271   if (attr.isSplat()) {
1272     processElt(*attr.begin(), /*index=*/0);
1273     return newArrayType;
1274   }
1275 
1276   // Otherwise, process all of the element values.
1277   uint64_t elementIdx = 0;
1278   for (auto value : attr)
1279     processElt(value, elementIdx++);
1280   return newArrayType;
1281 }
1282 
mapValues(Type newElementType,function_ref<APInt (const APFloat &)> mapping) const1283 DenseElementsAttr DenseFPElementsAttr::mapValues(
1284     Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1285   llvm::SmallVector<char, 8> elementData;
1286   auto newArrayType =
1287       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1288 
1289   return getRaw(newArrayType, elementData, isSplat());
1290 }
1291 
1292 /// Method for supporting type inquiry through isa, cast and dyn_cast.
classof(Attribute attr)1293 bool DenseFPElementsAttr::classof(Attribute attr) {
1294   return attr.isa<DenseElementsAttr>() &&
1295          attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
1296 }
1297 
1298 //===----------------------------------------------------------------------===//
1299 // DenseIntElementsAttr
1300 //===----------------------------------------------------------------------===//
1301 
mapValues(Type newElementType,function_ref<APInt (const APInt &)> mapping) const1302 DenseElementsAttr DenseIntElementsAttr::mapValues(
1303     Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1304   llvm::SmallVector<char, 8> elementData;
1305   auto newArrayType =
1306       mappingHelper(mapping, *this, getType(), newElementType, elementData);
1307 
1308   return getRaw(newArrayType, elementData, isSplat());
1309 }
1310 
1311 /// Method for supporting type inquiry through isa, cast and dyn_cast.
classof(Attribute attr)1312 bool DenseIntElementsAttr::classof(Attribute attr) {
1313   return attr.isa<DenseElementsAttr>() &&
1314          attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
1315 }
1316 
1317 //===----------------------------------------------------------------------===//
1318 // OpaqueElementsAttr
1319 //===----------------------------------------------------------------------===//
1320 
get(Dialect * dialect,ShapedType type,StringRef bytes)1321 OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type,
1322                                            StringRef bytes) {
1323   assert(TensorType::isValidElementType(type.getElementType()) &&
1324          "Input element type should be a valid tensor element type");
1325   return Base::get(type.getContext(), type, dialect, bytes);
1326 }
1327 
getValue() const1328 StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
1329 
1330 /// Return the value at the given index. If index does not refer to a valid
1331 /// element, then a null attribute is returned.
getValue(ArrayRef<uint64_t> index) const1332 Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
1333   assert(isValidIndex(index) && "expected valid multi-dimensional index");
1334   return Attribute();
1335 }
1336 
getDialect() const1337 Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }
1338 
decode(ElementsAttr & result)1339 bool OpaqueElementsAttr::decode(ElementsAttr &result) {
1340   auto *d = getDialect();
1341   if (!d)
1342     return true;
1343   auto *interface =
1344       d->getRegisteredInterface<DialectDecodeAttributesInterface>();
1345   if (!interface)
1346     return true;
1347   return failed(interface->decode(*this, result));
1348 }
1349 
1350 //===----------------------------------------------------------------------===//
1351 // SparseElementsAttr
1352 //===----------------------------------------------------------------------===//
1353 
get(ShapedType type,DenseElementsAttr indices,DenseElementsAttr values)1354 SparseElementsAttr SparseElementsAttr::get(ShapedType type,
1355                                            DenseElementsAttr indices,
1356                                            DenseElementsAttr values) {
1357   assert(indices.getType().getElementType().isInteger(64) &&
1358          "expected sparse indices to be 64-bit integer values");
1359   assert((type.isa<RankedTensorType, VectorType>()) &&
1360          "type must be ranked tensor or vector");
1361   assert(type.hasStaticShape() && "type must have static shape");
1362   return Base::get(type.getContext(), type,
1363                    indices.cast<DenseIntElementsAttr>(), values);
1364 }
1365 
getIndices() const1366 DenseIntElementsAttr SparseElementsAttr::getIndices() const {
1367   return getImpl()->indices;
1368 }
1369 
getValues() const1370 DenseElementsAttr SparseElementsAttr::getValues() const {
1371   return getImpl()->values;
1372 }
1373 
1374 /// Return the value of the element at the given index.
getValue(ArrayRef<uint64_t> index) const1375 Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
1376   assert(isValidIndex(index) && "expected valid multi-dimensional index");
1377   auto type = getType();
1378 
1379   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1380   // as a 1-D index array.
1381   auto sparseIndices = getIndices();
1382   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1383 
1384   // Check to see if the indices are a splat.
1385   if (sparseIndices.isSplat()) {
1386     // If the index is also not a splat of the index value, we know that the
1387     // value is zero.
1388     auto splatIndex = *sparseIndexValues.begin();
1389     if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
1390       return getZeroAttr();
1391 
1392     // If the indices are a splat, we also expect the values to be a splat.
1393     assert(getValues().isSplat() && "expected splat values");
1394     return getValues().getSplatValue();
1395   }
1396 
1397   // Build a mapping between known indices and the offset of the stored element.
1398   llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
1399   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1400   size_t rank = type.getRank();
1401   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1402     mappedIndices.try_emplace(
1403         {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);
1404 
1405   // Look for the provided index key within the mapped indices. If the provided
1406   // index is not found, then return a zero attribute.
1407   auto it = mappedIndices.find(index);
1408   if (it == mappedIndices.end())
1409     return getZeroAttr();
1410 
1411   // Otherwise, return the held sparse value element.
1412   return getValues().getValue(it->second);
1413 }
1414 
1415 /// Get a zero APFloat for the given sparse attribute.
getZeroAPFloat() const1416 APFloat SparseElementsAttr::getZeroAPFloat() const {
1417   auto eltType = getType().getElementType().cast<FloatType>();
1418   return APFloat(eltType.getFloatSemantics());
1419 }
1420 
1421 /// Get a zero APInt for the given sparse attribute.
getZeroAPInt() const1422 APInt SparseElementsAttr::getZeroAPInt() const {
1423   auto eltType = getType().getElementType().cast<IntegerType>();
1424   return APInt::getNullValue(eltType.getWidth());
1425 }
1426 
1427 /// Get a zero attribute for the given attribute type.
getZeroAttr() const1428 Attribute SparseElementsAttr::getZeroAttr() const {
1429   auto eltType = getType().getElementType();
1430 
1431   // Handle floating point elements.
1432   if (eltType.isa<FloatType>())
1433     return FloatAttr::get(eltType, 0);
1434 
1435   // Otherwise, this is an integer.
1436   // TODO: Handle StringAttr here.
1437   return IntegerAttr::get(eltType, 0);
1438 }
1439 
1440 /// Flatten, and return, all of the sparse indices in this attribute in
1441 /// row-major order.
getFlattenedSparseIndices() const1442 std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1443   std::vector<ptrdiff_t> flatSparseIndices;
1444 
1445   // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1446   // as a 1-D index array.
1447   auto sparseIndices = getIndices();
1448   auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1449   if (sparseIndices.isSplat()) {
1450     SmallVector<uint64_t, 8> indices(getType().getRank(),
1451                                      *sparseIndexValues.begin());
1452     flatSparseIndices.push_back(getFlattenedIndex(indices));
1453     return flatSparseIndices;
1454   }
1455 
1456   // Otherwise, reinterpret each index as an ArrayRef when flattening.
1457   auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1458   size_t rank = getType().getRank();
1459   for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1460     flatSparseIndices.push_back(getFlattenedIndex(
1461         {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1462   return flatSparseIndices;
1463 }
1464 
1465 //===----------------------------------------------------------------------===//
1466 // MutableDictionaryAttr
1467 //===----------------------------------------------------------------------===//
1468 
MutableDictionaryAttr(ArrayRef<NamedAttribute> attributes)1469 MutableDictionaryAttr::MutableDictionaryAttr(
1470     ArrayRef<NamedAttribute> attributes) {
1471   setAttrs(attributes);
1472 }
1473 
1474 /// Return the underlying dictionary attribute.
1475 DictionaryAttr
getDictionary(MLIRContext * context) const1476 MutableDictionaryAttr::getDictionary(MLIRContext *context) const {
1477   // Construct empty DictionaryAttr if needed.
1478   if (!attrs)
1479     return DictionaryAttr::get({}, context);
1480   return attrs;
1481 }
1482 
getAttrs() const1483 ArrayRef<NamedAttribute> MutableDictionaryAttr::getAttrs() const {
1484   return attrs ? attrs.getValue() : llvm::None;
1485 }
1486 
1487 /// Replace the held attributes with ones provided in 'newAttrs'.
setAttrs(ArrayRef<NamedAttribute> attributes)1488 void MutableDictionaryAttr::setAttrs(ArrayRef<NamedAttribute> attributes) {
1489   // Don't create an attribute list if there are no attributes.
1490   if (attributes.empty())
1491     attrs = nullptr;
1492   else
1493     attrs = DictionaryAttr::get(attributes, attributes[0].second.getContext());
1494 }
1495 
1496 /// Return the specified attribute if present, null otherwise.
get(StringRef name) const1497 Attribute MutableDictionaryAttr::get(StringRef name) const {
1498   return attrs ? attrs.get(name) : nullptr;
1499 }
1500 
1501 /// Return the specified attribute if present, null otherwise.
get(Identifier name) const1502 Attribute MutableDictionaryAttr::get(Identifier name) const {
1503   return attrs ? attrs.get(name) : nullptr;
1504 }
1505 
1506 /// Return the specified named attribute if present, None otherwise.
getNamed(StringRef name) const1507 Optional<NamedAttribute> MutableDictionaryAttr::getNamed(StringRef name) const {
1508   return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
1509 }
1510 Optional<NamedAttribute>
getNamed(Identifier name) const1511 MutableDictionaryAttr::getNamed(Identifier name) const {
1512   return attrs ? attrs.getNamed(name) : Optional<NamedAttribute>();
1513 }
1514 
1515 /// If the an attribute exists with the specified name, change it to the new
1516 /// value.  Otherwise, add a new attribute with the specified name/value.
set(Identifier name,Attribute value)1517 void MutableDictionaryAttr::set(Identifier name, Attribute value) {
1518   assert(value && "attributes may never be null");
1519 
1520   // Look for an existing value for the given name, and set it in-place.
1521   ArrayRef<NamedAttribute> values = getAttrs();
1522   const auto *it = llvm::find_if(
1523       values, [name](NamedAttribute attr) { return attr.first == name; });
1524   if (it != values.end()) {
1525     // Bail out early if the value is the same as what we already have.
1526     if (it->second == value)
1527       return;
1528 
1529     SmallVector<NamedAttribute, 8> newAttrs(values.begin(), values.end());
1530     newAttrs[it - values.begin()].second = value;
1531     attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext());
1532     return;
1533   }
1534 
1535   // Otherwise, insert the new attribute into its sorted position.
1536   it = llvm::lower_bound(values, name);
1537   SmallVector<NamedAttribute, 8> newAttrs;
1538   newAttrs.reserve(values.size() + 1);
1539   newAttrs.append(values.begin(), it);
1540   newAttrs.push_back({name, value});
1541   newAttrs.append(it, values.end());
1542   attrs = DictionaryAttr::getWithSorted(newAttrs, value.getContext());
1543 }
1544 
1545 /// Remove the attribute with the specified name if it exists.  The return
1546 /// value indicates whether the attribute was present or not.
remove(Identifier name)1547 auto MutableDictionaryAttr::remove(Identifier name) -> RemoveResult {
1548   auto origAttrs = getAttrs();
1549   for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) {
1550     if (origAttrs[i].first == name) {
1551       // Handle the simple case of removing the only attribute in the list.
1552       if (e == 1) {
1553         attrs = nullptr;
1554         return RemoveResult::Removed;
1555       }
1556 
1557       SmallVector<NamedAttribute, 8> newAttrs;
1558       newAttrs.reserve(origAttrs.size() - 1);
1559       newAttrs.append(origAttrs.begin(), origAttrs.begin() + i);
1560       newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end());
1561       attrs = DictionaryAttr::getWithSorted(newAttrs,
1562                                             newAttrs[0].second.getContext());
1563       return RemoveResult::Removed;
1564     }
1565   }
1566   return RemoveResult::NotFound;
1567 }
1568