1 //===- BuiltinAttributes.h - MLIR Builtin Attribute Classes -----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_IR_BUILTINATTRIBUTES_H 10 #define MLIR_IR_BUILTINATTRIBUTES_H 11 12 #include "mlir/IR/Attributes.h" 13 #include "llvm/ADT/APFloat.h" 14 #include "llvm/ADT/Sequence.h" 15 #include <complex> 16 17 namespace mlir { 18 class AffineMap; 19 class FunctionType; 20 class IntegerSet; 21 class Location; 22 class ShapedType; 23 24 namespace detail { 25 26 struct AffineMapAttributeStorage; 27 struct ArrayAttributeStorage; 28 struct DictionaryAttributeStorage; 29 struct IntegerAttributeStorage; 30 struct IntegerSetAttributeStorage; 31 struct FloatAttributeStorage; 32 struct OpaqueAttributeStorage; 33 struct StringAttributeStorage; 34 struct SymbolRefAttributeStorage; 35 struct TypeAttributeStorage; 36 37 /// Elements Attributes. 
38 struct DenseIntOrFPElementsAttributeStorage; 39 struct DenseStringElementsAttributeStorage; 40 struct OpaqueElementsAttributeStorage; 41 struct SparseElementsAttributeStorage; 42 } // namespace detail 43 44 //===----------------------------------------------------------------------===// 45 // AffineMapAttr 46 //===----------------------------------------------------------------------===// 47 48 class AffineMapAttr 49 : public Attribute::AttrBase<AffineMapAttr, Attribute, 50 detail::AffineMapAttributeStorage> { 51 public: 52 using Base::Base; 53 using ValueType = AffineMap; 54 55 static AffineMapAttr get(AffineMap value); 56 57 AffineMap getValue() const; 58 }; 59 60 //===----------------------------------------------------------------------===// 61 // ArrayAttr 62 //===----------------------------------------------------------------------===// 63 64 /// Array attributes are lists of other attributes. They are not necessarily 65 /// type homogenous given that attributes don't, in general, carry types. 66 class ArrayAttr : public Attribute::AttrBase<ArrayAttr, Attribute, 67 detail::ArrayAttributeStorage> { 68 public: 69 using Base::Base; 70 using ValueType = ArrayRef<Attribute>; 71 72 static ArrayAttr get(ArrayRef<Attribute> value, MLIRContext *context); 73 74 ArrayRef<Attribute> getValue() const; 75 Attribute operator[](unsigned idx) const; 76 77 /// Support range iteration. 78 using iterator = llvm::ArrayRef<Attribute>::iterator; begin()79 iterator begin() const { return getValue().begin(); } end()80 iterator end() const { return getValue().end(); } size()81 size_t size() const { return getValue().size(); } empty()82 bool empty() const { return size() == 0; } 83 84 private: 85 /// Class for underlying value iterator support. 
86 template <typename AttrTy> 87 class attr_value_iterator final 88 : public llvm::mapped_iterator<ArrayAttr::iterator, 89 AttrTy (*)(Attribute)> { 90 public: attr_value_iterator(ArrayAttr::iterator it)91 explicit attr_value_iterator(ArrayAttr::iterator it) 92 : llvm::mapped_iterator<ArrayAttr::iterator, AttrTy (*)(Attribute)>( 93 it, [](Attribute attr) { return attr.cast<AttrTy>(); }) {} 94 AttrTy operator*() const { return (*this->I).template cast<AttrTy>(); } 95 }; 96 97 public: 98 template <typename AttrTy> getAsRange()99 iterator_range<attr_value_iterator<AttrTy>> getAsRange() { 100 return llvm::make_range(attr_value_iterator<AttrTy>(begin()), 101 attr_value_iterator<AttrTy>(end())); 102 } 103 template <typename AttrTy, typename UnderlyingTy = typename AttrTy::ValueType> getAsValueRange()104 auto getAsValueRange() { 105 return llvm::map_range(getAsRange<AttrTy>(), [](AttrTy attr) { 106 return static_cast<UnderlyingTy>(attr.getValue()); 107 }); 108 } 109 }; 110 111 //===----------------------------------------------------------------------===// 112 // DictionaryAttr 113 //===----------------------------------------------------------------------===// 114 115 /// Dictionary attribute is an attribute that represents a sorted collection of 116 /// named attribute values. The elements are sorted by name, and each name must 117 /// be unique within the collection. 118 class DictionaryAttr 119 : public Attribute::AttrBase<DictionaryAttr, Attribute, 120 detail::DictionaryAttributeStorage> { 121 public: 122 using Base::Base; 123 using ValueType = ArrayRef<NamedAttribute>; 124 125 /// Construct a dictionary attribute with the provided list of named 126 /// attributes. This method assumes that the provided list is unordered. If 127 /// the caller can guarantee that the attributes are ordered by name, 128 /// getWithSorted should be used instead. 
129 static DictionaryAttr get(ArrayRef<NamedAttribute> value, 130 MLIRContext *context); 131 132 /// Construct a dictionary with an array of values that is known to already be 133 /// sorted by name and uniqued. 134 static DictionaryAttr getWithSorted(ArrayRef<NamedAttribute> value, 135 MLIRContext *context); 136 137 ArrayRef<NamedAttribute> getValue() const; 138 139 /// Return the specified attribute if present, null otherwise. 140 Attribute get(StringRef name) const; 141 Attribute get(Identifier name) const; 142 143 /// Return the specified named attribute if present, None otherwise. 144 Optional<NamedAttribute> getNamed(StringRef name) const; 145 Optional<NamedAttribute> getNamed(Identifier name) const; 146 147 /// Support range iteration. 148 using iterator = llvm::ArrayRef<NamedAttribute>::iterator; 149 iterator begin() const; 150 iterator end() const; empty()151 bool empty() const { return size() == 0; } 152 size_t size() const; 153 154 /// Sorts the NamedAttributes in the array ordered by name as expected by 155 /// getWithSorted and returns whether the values were sorted. 156 /// Requires: uniquely named attributes. 157 static bool sort(ArrayRef<NamedAttribute> values, 158 SmallVectorImpl<NamedAttribute> &storage); 159 160 /// Sorts the NamedAttributes in the array ordered by name as expected by 161 /// getWithSorted in place on an array and returns whether the values needed 162 /// to be sorted. 163 /// Requires: uniquely named attributes. 164 static bool sortInPlace(SmallVectorImpl<NamedAttribute> &array); 165 166 /// Returns an entry with a duplicate name in `array`, if it exists, else 167 /// returns llvm::None. If `isSorted` is true, the array is assumed to be 168 /// sorted else it will be sorted in place before finding the duplicate entry. 169 static Optional<NamedAttribute> 170 findDuplicate(SmallVectorImpl<NamedAttribute> &array, bool isSorted); 171 172 private: 173 /// Return empty dictionary. 
174 static DictionaryAttr getEmpty(MLIRContext *context); 175 }; 176 177 //===----------------------------------------------------------------------===// 178 // FloatAttr 179 //===----------------------------------------------------------------------===// 180 181 class FloatAttr : public Attribute::AttrBase<FloatAttr, Attribute, 182 detail::FloatAttributeStorage> { 183 public: 184 using Base::Base; 185 using ValueType = APFloat; 186 187 /// Return a float attribute for the specified value in the specified type. 188 /// These methods should only be used for simple constant values, e.g 1.0/2.0, 189 /// that are known-valid both as host double and the 'type' format. 190 static FloatAttr get(Type type, double value); 191 static FloatAttr getChecked(Type type, double value, Location loc); 192 193 /// Return a float attribute for the specified value in the specified type. 194 static FloatAttr get(Type type, const APFloat &value); 195 static FloatAttr getChecked(Type type, const APFloat &value, Location loc); 196 197 APFloat getValue() const; 198 199 /// This function is used to convert the value to a double, even if it loses 200 /// precision. 201 double getValueAsDouble() const; 202 static double getValueAsDouble(APFloat val); 203 204 /// Verify the construction invariants for a double value. 
205 static LogicalResult verifyConstructionInvariants(Location loc, Type type, 206 double value); 207 static LogicalResult verifyConstructionInvariants(Location loc, Type type, 208 const APFloat &value); 209 }; 210 211 //===----------------------------------------------------------------------===// 212 // IntegerAttr 213 //===----------------------------------------------------------------------===// 214 215 class IntegerAttr 216 : public Attribute::AttrBase<IntegerAttr, Attribute, 217 detail::IntegerAttributeStorage> { 218 public: 219 using Base::Base; 220 using ValueType = APInt; 221 222 static IntegerAttr get(Type type, int64_t value); 223 static IntegerAttr get(Type type, const APInt &value); 224 225 APInt getValue() const; 226 /// Return the integer value as a 64-bit int. The attribute must be a signless 227 /// integer. 228 // TODO: Change callers to use getValue instead. 229 int64_t getInt() const; 230 /// Return the integer value as a signed 64-bit int. The attribute must be 231 /// a signed integer. 232 int64_t getSInt() const; 233 /// Return the integer value as a unsigned 64-bit int. The attribute must be 234 /// an unsigned integer. 235 uint64_t getUInt() const; 236 237 static LogicalResult verifyConstructionInvariants(Location loc, Type type, 238 int64_t value); 239 static LogicalResult verifyConstructionInvariants(Location loc, Type type, 240 const APInt &value); 241 }; 242 243 //===----------------------------------------------------------------------===// 244 // BoolAttr 245 246 /// Special case of IntegerAttr to represent boolean integers, i.e., signless i1 247 /// integers. 248 class BoolAttr : public Attribute { 249 public: 250 using Attribute::Attribute; 251 using ValueType = bool; 252 253 static BoolAttr get(bool value, MLIRContext *context); 254 255 /// Enable conversion to IntegerAttr. This uses conversion vs. inheritance to 256 /// avoid bringing in all of IntegerAttrs methods. 
IntegerAttr()257 operator IntegerAttr() const { return IntegerAttr(impl); } 258 259 /// Return the boolean value of this attribute. 260 bool getValue() const; 261 262 /// Methods for support type inquiry through isa, cast, and dyn_cast. 263 static bool classof(Attribute attr); 264 }; 265 266 //===----------------------------------------------------------------------===// 267 // IntegerSetAttr 268 //===----------------------------------------------------------------------===// 269 270 class IntegerSetAttr 271 : public Attribute::AttrBase<IntegerSetAttr, Attribute, 272 detail::IntegerSetAttributeStorage> { 273 public: 274 using Base::Base; 275 using ValueType = IntegerSet; 276 277 static IntegerSetAttr get(IntegerSet value); 278 279 IntegerSet getValue() const; 280 }; 281 282 //===----------------------------------------------------------------------===// 283 // OpaqueAttr 284 //===----------------------------------------------------------------------===// 285 286 /// Opaque attributes represent attributes of non-registered dialects. These are 287 /// attribute represented in their raw string form, and can only usefully be 288 /// tested for attribute equality. 289 class OpaqueAttr : public Attribute::AttrBase<OpaqueAttr, Attribute, 290 detail::OpaqueAttributeStorage> { 291 public: 292 using Base::Base; 293 294 /// Get or create a new OpaqueAttr with the provided dialect and string data. 295 static OpaqueAttr get(Identifier dialect, StringRef attrData, Type type, 296 MLIRContext *context); 297 298 /// Get or create a new OpaqueAttr with the provided dialect and string data. 299 /// If the given identifier is not a valid namespace for a dialect, then a 300 /// null attribute is returned. 301 static OpaqueAttr getChecked(Identifier dialect, StringRef attrData, 302 Type type, Location location); 303 304 /// Returns the dialect namespace of the opaque attribute. 
305 Identifier getDialectNamespace() const; 306 307 /// Returns the raw attribute data of the opaque attribute. 308 StringRef getAttrData() const; 309 310 /// Verify the construction of an opaque attribute. 311 static LogicalResult verifyConstructionInvariants(Location loc, 312 Identifier dialect, 313 StringRef attrData, 314 Type type); 315 }; 316 317 //===----------------------------------------------------------------------===// 318 // StringAttr 319 //===----------------------------------------------------------------------===// 320 321 class StringAttr : public Attribute::AttrBase<StringAttr, Attribute, 322 detail::StringAttributeStorage> { 323 public: 324 using Base::Base; 325 using ValueType = StringRef; 326 327 /// Get an instance of a StringAttr with the given string. 328 static StringAttr get(StringRef bytes, MLIRContext *context); 329 330 /// Get an instance of a StringAttr with the given string and Type. 331 static StringAttr get(StringRef bytes, Type type); 332 333 StringRef getValue() const; 334 }; 335 336 //===----------------------------------------------------------------------===// 337 // SymbolRefAttr 338 //===----------------------------------------------------------------------===// 339 340 class FlatSymbolRefAttr; 341 342 /// A symbol reference attribute represents a symbolic reference to another 343 /// operation. 344 class SymbolRefAttr 345 : public Attribute::AttrBase<SymbolRefAttr, Attribute, 346 detail::SymbolRefAttributeStorage> { 347 public: 348 using Base::Base; 349 350 /// Construct a symbol reference for the given value name. 351 static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx); 352 353 /// Construct a symbol reference for the given value name, and a set of nested 354 /// references that are further resolve to a nested symbol. 355 static SymbolRefAttr get(StringRef value, 356 ArrayRef<FlatSymbolRefAttr> references, 357 MLIRContext *ctx); 358 359 /// Returns the name of the top level symbol reference, i.e. 
the root of the 360 /// reference path. 361 StringRef getRootReference() const; 362 363 /// Returns the name of the fully resolved symbol, i.e. the leaf of the 364 /// reference path. 365 StringRef getLeafReference() const; 366 367 /// Returns the set of nested references representing the path to the symbol 368 /// nested under the root reference. 369 ArrayRef<FlatSymbolRefAttr> getNestedReferences() const; 370 }; 371 372 /// A symbol reference with a reference path containing a single element. This 373 /// is used to refer to an operation within the current symbol table. 374 class FlatSymbolRefAttr : public SymbolRefAttr { 375 public: 376 using SymbolRefAttr::SymbolRefAttr; 377 using ValueType = StringRef; 378 379 /// Construct a symbol reference for the given value name. get(StringRef value,MLIRContext * ctx)380 static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx) { 381 return SymbolRefAttr::get(value, ctx); 382 } 383 384 /// Returns the name of the held symbol reference. getValue()385 StringRef getValue() const { return getRootReference(); } 386 387 /// Methods for support type inquiry through isa, cast, and dyn_cast. 
classof(Attribute attr)388 static bool classof(Attribute attr) { 389 SymbolRefAttr refAttr = attr.dyn_cast<SymbolRefAttr>(); 390 return refAttr && refAttr.getNestedReferences().empty(); 391 } 392 393 private: 394 using SymbolRefAttr::get; 395 using SymbolRefAttr::getNestedReferences; 396 }; 397 398 //===----------------------------------------------------------------------===// 399 // Type 400 //===----------------------------------------------------------------------===// 401 402 class TypeAttr : public Attribute::AttrBase<TypeAttr, Attribute, 403 detail::TypeAttributeStorage> { 404 public: 405 using Base::Base; 406 using ValueType = Type; 407 408 static TypeAttr get(Type value); 409 410 Type getValue() const; 411 }; 412 413 //===----------------------------------------------------------------------===// 414 // UnitAttr 415 //===----------------------------------------------------------------------===// 416 417 /// Unit attributes are attributes that hold no specific value and are given 418 /// meaning by their existence. 419 class UnitAttr 420 : public Attribute::AttrBase<UnitAttr, Attribute, AttributeStorage> { 421 public: 422 using Base::Base; 423 424 static UnitAttr get(MLIRContext *context); 425 }; 426 427 //===----------------------------------------------------------------------===// 428 // Elements Attributes 429 //===----------------------------------------------------------------------===// 430 431 namespace detail { 432 template <typename T> 433 class ElementsAttrIterator; 434 template <typename T> 435 class ElementsAttrRange; 436 } // namespace detail 437 438 /// A base attribute that represents a reference to a static shaped tensor or 439 /// vector constant. 
440 class ElementsAttr : public Attribute { 441 public: 442 using Attribute::Attribute; 443 template <typename T> 444 using iterator = detail::ElementsAttrIterator<T>; 445 template <typename T> 446 using iterator_range = detail::ElementsAttrRange<T>; 447 448 /// Return the type of this ElementsAttr, guaranteed to be a vector or tensor 449 /// with static shape. 450 ShapedType getType() const; 451 452 /// Return the value at the given index. The index is expected to refer to a 453 /// valid element. 454 Attribute getValue(ArrayRef<uint64_t> index) const; 455 456 /// Return the value of type 'T' at the given index, where 'T' corresponds to 457 /// an Attribute type. 458 template <typename T> getValue(ArrayRef<uint64_t> index)459 T getValue(ArrayRef<uint64_t> index) const { 460 return getValue(index).template cast<T>(); 461 } 462 463 /// Return the elements of this attribute as a value of type 'T'. Note: 464 /// Aborts if the subclass is OpaqueElementsAttrs, these attrs do not support 465 /// iteration. 466 template <typename T> 467 iterator_range<T> getValues() const; 468 469 /// Return if the given 'index' refers to a valid element in this attribute. 470 bool isValidIndex(ArrayRef<uint64_t> index) const; 471 472 /// Returns the number of elements held by this attribute. 473 int64_t getNumElements() const; 474 475 /// Returns the number of elements held by this attribute. size()476 int64_t size() const { return getNumElements(); } 477 478 /// Generates a new ElementsAttr by mapping each int value to a new 479 /// underlying APInt. The new values can represent either an integer or float. 480 /// This ElementsAttr should contain integers. 481 ElementsAttr mapValues(Type newElementType, 482 function_ref<APInt(const APInt &)> mapping) const; 483 484 /// Generates a new ElementsAttr by mapping each float value to a new 485 /// underlying APInt. The new values can represent either an integer or float. 486 /// This ElementsAttr should contain floats. 
487 ElementsAttr mapValues(Type newElementType, 488 function_ref<APInt(const APFloat &)> mapping) const; 489 490 /// Method for support type inquiry through isa, cast and dyn_cast. 491 static bool classof(Attribute attr); 492 493 protected: 494 /// Returns the 1 dimensional flattened row-major index from the given 495 /// multi-dimensional index. 496 uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const; 497 }; 498 499 namespace detail { 500 /// DenseElementsAttr data is aligned to uint64_t, so this traits class is 501 /// necessary to interop with PointerIntPair. 502 class DenseElementDataPointerTypeTraits { 503 public: getAsVoidPointer(const char * ptr)504 static inline const void *getAsVoidPointer(const char *ptr) { return ptr; } getFromVoidPointer(const void * ptr)505 static inline const char *getFromVoidPointer(const void *ptr) { 506 return static_cast<const char *>(ptr); 507 } 508 509 // Note: We could steal more bits if the need arises. 510 static constexpr int NumLowBitsAvailable = 1; 511 }; 512 513 /// Pair of raw pointer and a boolean flag of whether the pointer holds a splat, 514 using DenseIterPtrAndSplat = 515 llvm::PointerIntPair<const char *, 1, bool, 516 DenseElementDataPointerTypeTraits>; 517 518 /// Impl iterator for indexed DenseElementsAttr iterators that records a data 519 /// pointer and data index that is adjusted for the case of a splat attribute. 
520 template <typename ConcreteT, typename T, typename PointerT = T *, 521 typename ReferenceT = T &> 522 class DenseElementIndexedIteratorImpl 523 : public llvm::indexed_accessor_iterator<ConcreteT, DenseIterPtrAndSplat, T, 524 PointerT, ReferenceT> { 525 protected: DenseElementIndexedIteratorImpl(const char * data,bool isSplat,size_t dataIndex)526 DenseElementIndexedIteratorImpl(const char *data, bool isSplat, 527 size_t dataIndex) 528 : llvm::indexed_accessor_iterator<ConcreteT, DenseIterPtrAndSplat, T, 529 PointerT, ReferenceT>({data, isSplat}, 530 dataIndex) {} 531 532 /// Return the current index for this iterator, adjusted for the case of a 533 /// splat. getDataIndex()534 ptrdiff_t getDataIndex() const { 535 bool isSplat = this->base.getInt(); 536 return isSplat ? 0 : this->index; 537 } 538 539 /// Return the data base pointer. getData()540 const char *getData() const { return this->base.getPointer(); } 541 }; 542 543 /// Type trait detector that checks if a given type T is a complex type. 544 template <typename T> 545 struct is_complex_t : public std::false_type {}; 546 template <typename T> 547 struct is_complex_t<std::complex<T>> : public std::true_type {}; 548 } // namespace detail 549 550 /// An attribute that represents a reference to a dense vector or tensor object. 551 /// 552 class DenseElementsAttr : public ElementsAttr { 553 public: 554 using ElementsAttr::ElementsAttr; 555 556 /// Type trait used to check if the given type T is a potentially valid C++ 557 /// floating point type that can be used to access the underlying element 558 /// types of a DenseElementsAttr. 559 // TODO: Use std::disjunction when C++17 is supported. 560 template <typename T> 561 struct is_valid_cpp_fp_type { 562 /// The type is a valid floating point type if it is a builtin floating 563 /// point type, or is a potentially user defined floating point type. The 564 /// latter allows for supporting users that have custom types defined for 565 /// bfloat16/half/etc. 
566 static constexpr bool value = llvm::is_one_of<T, float, double>::value || 567 (std::numeric_limits<T>::is_specialized && 568 !std::numeric_limits<T>::is_integer); 569 }; 570 571 /// Method for support type inquiry through isa, cast and dyn_cast. 572 static bool classof(Attribute attr); 573 574 /// Constructs a dense elements attribute from an array of element values. 575 /// Each element attribute value is expected to be an element of 'type'. 576 /// 'type' must be a vector or tensor with static shape. If the element of 577 /// `type` is non-integer/index/float it is assumed to be a string type. 578 static DenseElementsAttr get(ShapedType type, ArrayRef<Attribute> values); 579 580 /// Constructs a dense integer elements attribute from an array of integer 581 /// or floating-point values. Each value is expected to be the same bitwidth 582 /// of the element type of 'type'. 'type' must be a vector or tensor with 583 /// static shape. 584 template <typename T, typename = typename std::enable_if< 585 std::numeric_limits<T>::is_integer || 586 is_valid_cpp_fp_type<T>::value>::type> 587 static DenseElementsAttr get(const ShapedType &type, ArrayRef<T> values) { 588 const char *data = reinterpret_cast<const char *>(values.data()); 589 return getRawIntOrFloat( 590 type, ArrayRef<char>(data, values.size() * sizeof(T)), sizeof(T), 591 std::numeric_limits<T>::is_integer, std::numeric_limits<T>::is_signed); 592 } 593 594 /// Constructs a dense integer elements attribute from a single element. 595 template <typename T, typename = typename std::enable_if< 596 std::numeric_limits<T>::is_integer || 597 is_valid_cpp_fp_type<T>::value || 598 detail::is_complex_t<T>::value>::type> 599 static DenseElementsAttr get(const ShapedType &type, T value) { 600 return get(type, llvm::makeArrayRef(value)); 601 } 602 603 /// Constructs a dense complex elements attribute from an array of complex 604 /// values. 
Each value is expected to be the same bitwidth of the element type 605 /// of 'type'. 'type' must be a vector or tensor with static shape. 606 template <typename T, typename ElementT = typename T::value_type, 607 typename = typename std::enable_if< 608 detail::is_complex_t<T>::value && 609 (std::numeric_limits<ElementT>::is_integer || 610 is_valid_cpp_fp_type<ElementT>::value)>::type> 611 static DenseElementsAttr get(const ShapedType &type, ArrayRef<T> values) { 612 const char *data = reinterpret_cast<const char *>(values.data()); 613 return getRawComplex(type, ArrayRef<char>(data, values.size() * sizeof(T)), 614 sizeof(T), std::numeric_limits<ElementT>::is_integer, 615 std::numeric_limits<ElementT>::is_signed); 616 } 617 618 /// Overload of the above 'get' method that is specialized for boolean values. 619 static DenseElementsAttr get(ShapedType type, ArrayRef<bool> values); 620 621 /// Overload of the above 'get' method that is specialized for StringRef 622 /// values. 623 static DenseElementsAttr get(ShapedType type, ArrayRef<StringRef> values); 624 625 /// Constructs a dense integer elements attribute from an array of APInt 626 /// values. Each APInt value is expected to have the same bitwidth as the 627 /// element type of 'type'. 'type' must be a vector or tensor with static 628 /// shape. 629 static DenseElementsAttr get(ShapedType type, ArrayRef<APInt> values); 630 631 /// Constructs a dense complex elements attribute from an array of APInt 632 /// values. Each APInt value is expected to have the same bitwidth as the 633 /// element type of 'type'. 'type' must be a vector or tensor with static 634 /// shape. 635 static DenseElementsAttr get(ShapedType type, 636 ArrayRef<std::complex<APInt>> values); 637 638 /// Constructs a dense float elements attribute from an array of APFloat 639 /// values. Each APFloat value is expected to have the same bitwidth as the 640 /// element type of 'type'. 'type' must be a vector or tensor with static 641 /// shape. 
642 static DenseElementsAttr get(ShapedType type, ArrayRef<APFloat> values); 643 644 /// Constructs a dense complex elements attribute from an array of APFloat 645 /// values. Each APFloat value is expected to have the same bitwidth as the 646 /// element type of 'type'. 'type' must be a vector or tensor with static 647 /// shape. 648 static DenseElementsAttr get(ShapedType type, 649 ArrayRef<std::complex<APFloat>> values); 650 651 /// Construct a dense elements attribute for an initializer_list of values. 652 /// Each value is expected to be the same bitwidth of the element type of 653 /// 'type'. 'type' must be a vector or tensor with static shape. 654 template <typename T> 655 static DenseElementsAttr get(const ShapedType &type, 656 const std::initializer_list<T> &list) { 657 return get(type, ArrayRef<T>(list)); 658 } 659 660 /// Construct a dense elements attribute from a raw buffer representing the 661 /// data for this attribute. Users should generally not use this methods as 662 /// the expected buffer format may not be a form the user expects. 663 static DenseElementsAttr getFromRawBuffer(ShapedType type, 664 ArrayRef<char> rawBuffer, 665 bool isSplatBuffer); 666 667 /// Returns true if the given buffer is a valid raw buffer for the given type. 668 /// `detectedSplat` is set if the buffer is valid and represents a splat 669 /// buffer. 670 static bool isValidRawBuffer(ShapedType type, ArrayRef<char> rawBuffer, 671 bool &detectedSplat); 672 673 //===--------------------------------------------------------------------===// 674 // Iterators 675 //===--------------------------------------------------------------------===// 676 677 /// A utility iterator that allows walking over the internal Attribute values 678 /// of a DenseElementsAttr. 
679 class AttributeElementIterator 680 : public llvm::indexed_accessor_iterator<AttributeElementIterator, 681 const void *, Attribute, 682 Attribute, Attribute> { 683 public: 684 /// Accesses the Attribute value at this iterator position. 685 Attribute operator*() const; 686 687 private: 688 friend DenseElementsAttr; 689 690 /// Constructs a new iterator. 691 AttributeElementIterator(DenseElementsAttr attr, size_t index); 692 }; 693 694 /// Iterator for walking raw element values of the specified type 'T', which 695 /// may be any c++ data type matching the stored representation: int32_t, 696 /// float, etc. 697 template <typename T> 698 class ElementIterator 699 : public detail::DenseElementIndexedIteratorImpl<ElementIterator<T>, 700 const T> { 701 public: 702 /// Accesses the raw value at this iterator position. 703 const T &operator*() const { 704 return reinterpret_cast<const T *>(this->getData())[this->getDataIndex()]; 705 } 706 707 private: 708 friend DenseElementsAttr; 709 710 /// Constructs a new iterator. 711 ElementIterator(const char *data, bool isSplat, size_t dataIndex) 712 : detail::DenseElementIndexedIteratorImpl<ElementIterator<T>, const T>( 713 data, isSplat, dataIndex) {} 714 }; 715 716 /// A utility iterator that allows walking over the internal bool values. 717 class BoolElementIterator 718 : public detail::DenseElementIndexedIteratorImpl<BoolElementIterator, 719 bool, bool, bool> { 720 public: 721 /// Accesses the bool value at this iterator position. 722 bool operator*() const; 723 724 private: 725 friend DenseElementsAttr; 726 727 /// Constructs a new iterator. 728 BoolElementIterator(DenseElementsAttr attr, size_t dataIndex); 729 }; 730 731 /// A utility iterator that allows walking over the internal raw APInt values. 732 class IntElementIterator 733 : public detail::DenseElementIndexedIteratorImpl<IntElementIterator, 734 APInt, APInt, APInt> { 735 public: 736 /// Accesses the raw APInt value at this iterator position. 
737 APInt operator*() const; 738 739 private: 740 friend DenseElementsAttr; 741 742 /// Constructs a new iterator. 743 IntElementIterator(DenseElementsAttr attr, size_t dataIndex); 744 745 /// The bitwidth of the element type. 746 size_t bitWidth; 747 }; 748 749 /// A utility iterator that allows walking over the internal raw complex APInt 750 /// values. 751 class ComplexIntElementIterator 752 : public detail::DenseElementIndexedIteratorImpl< 753 ComplexIntElementIterator, std::complex<APInt>, std::complex<APInt>, 754 std::complex<APInt>> { 755 public: 756 /// Accesses the raw std::complex<APInt> value at this iterator position. 757 std::complex<APInt> operator*() const; 758 759 private: 760 friend DenseElementsAttr; 761 762 /// Constructs a new iterator. 763 ComplexIntElementIterator(DenseElementsAttr attr, size_t dataIndex); 764 765 /// The bitwidth of the element type. 766 size_t bitWidth; 767 }; 768 769 /// Iterator for walking over APFloat values. 770 class FloatElementIterator final 771 : public llvm::mapped_iterator<IntElementIterator, 772 std::function<APFloat(const APInt &)>> { 773 friend DenseElementsAttr; 774 775 /// Initializes the float element iterator to the specified iterator. 776 FloatElementIterator(const llvm::fltSemantics &smt, IntElementIterator it); 777 778 public: 779 using reference = APFloat; 780 }; 781 782 /// Iterator for walking over complex APFloat values. 783 class ComplexFloatElementIterator final 784 : public llvm::mapped_iterator< 785 ComplexIntElementIterator, 786 std::function<std::complex<APFloat>(const std::complex<APInt> &)>> { 787 friend DenseElementsAttr; 788 789 /// Initializes the float element iterator to the specified iterator. 
790 ComplexFloatElementIterator(const llvm::fltSemantics &smt, 791 ComplexIntElementIterator it); 792 793 public: 794 using reference = std::complex<APFloat>; 795 }; 796 797 //===--------------------------------------------------------------------===// 798 // Value Querying 799 //===--------------------------------------------------------------------===// 800 801 /// Returns true if this attribute corresponds to a splat, i.e. if all element 802 /// values are the same. 803 bool isSplat() const; 804 805 /// Return the splat value for this attribute. This asserts that the attribute 806 /// corresponds to a splat. 807 Attribute getSplatValue() const { return getSplatValue<Attribute>(); } 808 template <typename T> 809 typename std::enable_if<!std::is_base_of<Attribute, T>::value || 810 std::is_same<Attribute, T>::value, 811 T>::type 812 getSplatValue() const { 813 assert(isSplat() && "expected the attribute to be a splat"); 814 return *getValues<T>().begin(); 815 } 816 /// Return the splat value for derived attribute element types. 817 template <typename T> 818 typename std::enable_if<std::is_base_of<Attribute, T>::value && 819 !std::is_same<Attribute, T>::value, 820 T>::type 821 getSplatValue() const { 822 return getSplatValue().template cast<T>(); 823 } 824 825 /// Return the value at the given index. The 'index' is expected to refer to a 826 /// valid element. 827 Attribute getValue(ArrayRef<uint64_t> index) const { 828 return getValue<Attribute>(index); 829 } 830 template <typename T> 831 T getValue(ArrayRef<uint64_t> index) const { 832 // Skip to the element corresponding to the flattened index. 833 return *std::next(getValues<T>().begin(), getFlattenedIndex(index)); 834 } 835 836 /// Return the held element values as a range of integer or floating-point 837 /// values. 
838 template <typename T, typename = typename std::enable_if< 839 (!std::is_same<T, bool>::value && 840 std::numeric_limits<T>::is_integer) || 841 is_valid_cpp_fp_type<T>::value>::type> 842 llvm::iterator_range<ElementIterator<T>> getValues() const { 843 assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer, 844 std::numeric_limits<T>::is_signed)); 845 const char *rawData = getRawData().data(); 846 bool splat = isSplat(); 847 return {ElementIterator<T>(rawData, splat, 0), 848 ElementIterator<T>(rawData, splat, getNumElements())}; 849 } 850 851 /// Return the held element values as a range of std::complex. 852 template <typename T, typename ElementT = typename T::value_type, 853 typename = typename std::enable_if< 854 detail::is_complex_t<T>::value && 855 (std::numeric_limits<ElementT>::is_integer || 856 is_valid_cpp_fp_type<ElementT>::value)>::type> 857 llvm::iterator_range<ElementIterator<T>> getValues() const { 858 assert(isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer, 859 std::numeric_limits<ElementT>::is_signed)); 860 const char *rawData = getRawData().data(); 861 bool splat = isSplat(); 862 return {ElementIterator<T>(rawData, splat, 0), 863 ElementIterator<T>(rawData, splat, getNumElements())}; 864 } 865 866 /// Return the held element values as a range of StringRef. 867 template <typename T, typename = typename std::enable_if< 868 std::is_same<T, StringRef>::value>::type> 869 llvm::iterator_range<ElementIterator<StringRef>> getValues() const { 870 auto stringRefs = getRawStringData(); 871 const char *ptr = reinterpret_cast<const char *>(stringRefs.data()); 872 bool splat = isSplat(); 873 return {ElementIterator<StringRef>(ptr, splat, 0), 874 ElementIterator<StringRef>(ptr, splat, getNumElements())}; 875 } 876 877 /// Return the held element values as a range of Attributes. 
llvm::iterator_range<AttributeElementIterator> getAttributeValues() const;
/// 'getValues<Attribute>()' simply forwards to 'getAttributeValues()'.
template <typename T, typename = typename std::enable_if<
                          std::is_same<T, Attribute>::value>::type>
llvm::iterator_range<AttributeElementIterator> getValues() const {
  return getAttributeValues();
}
AttributeElementIterator attr_value_begin() const;
AttributeElementIterator attr_value_end() const;

/// Return the held element values as a range of T, where T is a derived
/// attribute type. Each generic Attribute element is mapped through a cast to
/// 'T'.
template <typename T>
using DerivedAttributeElementIterator =
    llvm::mapped_iterator<AttributeElementIterator, T (*)(Attribute)>;
template <typename T, typename = typename std::enable_if<
                          std::is_base_of<Attribute, T>::value &&
                          !std::is_same<Attribute, T>::value>::type>
llvm::iterator_range<DerivedAttributeElementIterator<T>> getValues() const {
  // The capture-less lambda is converted to a plain function pointer so that
  // it matches the mapping type of DerivedAttributeElementIterator.
  auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
  return llvm::map_range(getAttributeValues(),
                         static_cast<T (*)(Attribute)>(castFn));
}

/// Return the held element values as a range of bool. The element type of
/// this attribute must be of integer type of bitwidth 1.
llvm::iterator_range<BoolElementIterator> getBoolValues() const;
/// 'getValues<bool>()' simply forwards to 'getBoolValues()'.
template <typename T, typename = typename std::enable_if<
                          std::is_same<T, bool>::value>::type>
llvm::iterator_range<BoolElementIterator> getValues() const {
  return getBoolValues();
}

/// Return the held element values as a range of APInts. The element type of
/// this attribute must be of integer type.
llvm::iterator_range<IntElementIterator> getIntValues() const;
/// 'getValues<APInt>()' simply forwards to 'getIntValues()'.
template <typename T, typename = typename std::enable_if<
                          std::is_same<T, APInt>::value>::type>
llvm::iterator_range<IntElementIterator> getValues() const {
  return getIntValues();
}
IntElementIterator int_value_begin() const;
IntElementIterator int_value_end() const;

/// Return the held element values as a range of complex APInts. The element
/// type of this attribute must be a complex of integer type.
llvm::iterator_range<ComplexIntElementIterator> getComplexIntValues() const;
/// 'getValues<std::complex<APInt>>()' simply forwards to
/// 'getComplexIntValues()'.
template <typename T, typename = typename std::enable_if<
                          std::is_same<T, std::complex<APInt>>::value>::type>
llvm::iterator_range<ComplexIntElementIterator> getValues() const {
  return getComplexIntValues();
}

/// Return the held element values as a range of APFloat. The element type of
/// this attribute must be of float type.
llvm::iterator_range<FloatElementIterator> getFloatValues() const;
/// 'getValues<APFloat>()' simply forwards to 'getFloatValues()'.
template <typename T, typename = typename std::enable_if<
                          std::is_same<T, APFloat>::value>::type>
llvm::iterator_range<FloatElementIterator> getValues() const {
  return getFloatValues();
}
FloatElementIterator float_value_begin() const;
FloatElementIterator float_value_end() const;

/// Return the held element values as a range of complex APFloat. The element
/// type of this attribute must be a complex of float type.
llvm::iterator_range<ComplexFloatElementIterator>
getComplexFloatValues() const;
/// 'getValues<std::complex<APFloat>>()' simply forwards to
/// 'getComplexFloatValues()'.
template <typename T, typename = typename std::enable_if<std::is_same<
                          T, std::complex<APFloat>>::value>::type>
llvm::iterator_range<ComplexFloatElementIterator> getValues() const {
  return getComplexFloatValues();
}

/// Return the raw storage data held by this attribute. Users should generally
/// not use this directly, as the internal storage format is not always in the
/// form the user might expect.
ArrayRef<char> getRawData() const;

/// Return the raw StringRef data held by this attribute.
ArrayRef<StringRef> getRawStringData() const;

//===--------------------------------------------------------------------===//
// Mutation Utilities
//===--------------------------------------------------------------------===//

/// Return a new DenseElementsAttr that has the same data as the current
/// attribute, but has been reshaped to 'newType'. The new type must have the
/// same total number of elements as well as element type.
DenseElementsAttr reshape(ShapedType newType);

/// Generates a new DenseElementsAttr by mapping each int value to a new
/// underlying APInt. The new values can represent either an integer or float.
/// This underlying type must be a DenseIntElementsAttr.
DenseElementsAttr mapValues(Type newElementType,
                            function_ref<APInt(const APInt &)> mapping) const;

/// Generates a new DenseElementsAttr by mapping each float value to a new
/// underlying APInt. The new values can represent either an integer or float.
/// This underlying type must be a DenseFPElementsAttr.
DenseElementsAttr
mapValues(Type newElementType,
          function_ref<APInt(const APFloat &)> mapping) const;

protected:
/// Get iterators to the raw APInt values for each element in this attribute.
/// The begin iterator starts at element 0 and the end iterator at
/// 'getNumElements()'.
IntElementIterator raw_int_begin() const {
  return IntElementIterator(*this, 0);
}
IntElementIterator raw_int_end() const {
  return IntElementIterator(*this, getNumElements());
}

/// Overload of the raw 'get' method that asserts that the given type is of
/// complex type. This method is used to verify type invariants that the
/// templatized 'get' method cannot.
static DenseElementsAttr getRawComplex(ShapedType type, ArrayRef<char> data,
                                       int64_t dataEltSize, bool isInt,
                                       bool isSigned);

/// Overload of the raw 'get' method that asserts that the given type is of
/// integer or floating-point type. This method is used to verify type
/// invariants that the templatized 'get' method cannot.
static DenseElementsAttr getRawIntOrFloat(ShapedType type,
                                          ArrayRef<char> data,
                                          int64_t dataEltSize, bool isInt,
                                          bool isSigned);

/// Given the information for a C++ data type, check if this type is valid
/// for the current attribute. This method is used to verify specific type
/// invariants that the templatized 'getValues' method cannot. The callers
/// in this class pass 'sizeof(T)' together with the 'is_integer' and
/// 'is_signed' traits of 'T'.
bool isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const;

/// Given the information for a C++ data type, check if this type is valid
/// for the current attribute. This method is used to verify specific type
/// invariants that the templatized 'getValues' method cannot. The callers
/// in this class pass 'sizeof(T)' of the complex type together with the
/// 'is_integer' and 'is_signed' traits of its element type.
bool isValidComplex(int64_t dataEltSize, bool isInt, bool isSigned) const;
};

/// An attribute class for representing dense arrays of strings. The structure
/// supports storing and querying a list of densely packed strings.
class DenseStringElementsAttr
    : public Attribute::AttrBase<DenseStringElementsAttr, DenseElementsAttr,
                                 detail::DenseStringElementsAttributeStorage> {

public:
  using Base::Base;

  /// Get an instance of a DenseStringElementsAttr for the given shaped 'type'
  /// and string 'data'.
  /// NOTE(review): the invariants checked on 'type' and 'data' live in the
  /// out-of-line definition; confirm them there.
  static DenseStringElementsAttr get(ShapedType type, ArrayRef<StringRef> data);

protected:
  friend DenseElementsAttr;
};

/// An attribute class for specializing behavior of integer and floating-point
/// densely packed arrays.
class DenseIntOrFPElementsAttr
    : public Attribute::AttrBase<DenseIntOrFPElementsAttr, DenseElementsAttr,
                                 detail::DenseIntOrFPElementsAttributeStorage> {

public:
  using Base::Base;

  /// Convert endianness of input ArrayRef for big-endian(BE) machines. All of
  /// the elements of `inRawData` have `type`. If `inRawData` is little endian
  /// (LE), it is converted to big endian (BE). Conversely, if `inRawData` is
  /// BE, it is converted to LE. The result is written to `outRawData`.
  static void
  convertEndianOfArrayRefForBEmachine(ArrayRef<char> inRawData,
                                      MutableArrayRef<char> outRawData,
                                      ShapedType type);

  /// Convert endianness of input for big-endian(BE) machines. The number of
  /// elements of `inRawData` is `numElements`, and each element has
  /// `elementBitWidth` bits. If `inRawData` is little endian (LE), it is
  /// converted to big endian (BE) and saved in `outRawData`. Conversely, if
  /// `inRawData` is BE, it is converted to LE.
  static void convertEndianOfCharForBEmachine(const char *inRawData,
                                              char *outRawData,
                                              size_t elementBitWidth,
                                              size_t numElements);

protected:
  friend DenseElementsAttr;

  /// Constructs a dense elements attribute from an array of raw APFloat values.
  /// Each APFloat value is expected to have the same bitwidth as the element
  /// type of 'type'. 'type' must be a vector or tensor with static shape.
  static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth,
                                  ArrayRef<APFloat> values, bool isSplat);

  /// Constructs a dense elements attribute from an array of raw APInt values.
  /// Each APInt value is expected to have the same bitwidth as the element type
  /// of 'type'. 'type' must be a vector or tensor with static shape.
  static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth,
                                  ArrayRef<APInt> values, bool isSplat);

  /// Get or create a new dense elements attribute instance with the given raw
  /// data buffer. 'type' must be a vector or tensor with static shape.
  static DenseElementsAttr getRaw(ShapedType type, ArrayRef<char> data,
                                  bool isSplat);

  /// Overload of the raw 'get' method that asserts that the given type is of
  /// complex type. This method is used to verify type invariants that the
  /// templatized 'get' method cannot.
  static DenseElementsAttr getRawComplex(ShapedType type, ArrayRef<char> data,
                                         int64_t dataEltSize, bool isInt,
                                         bool isSigned);

  /// Overload of the raw 'get' method that asserts that the given type is of
  /// integer or floating-point type. This method is used to verify type
  /// invariants that the templatized 'get' method cannot.
  static DenseElementsAttr getRawIntOrFloat(ShapedType type,
                                            ArrayRef<char> data,
                                            int64_t dataEltSize, bool isInt,
                                            bool isSigned);
};

/// An attribute that represents a reference to a dense float vector or tensor
/// object. Each element is stored as a double.
/// NOTE(review): the "stored as a double" statement is not established by
/// anything in this header; confirm it against the storage implementation.
class DenseFPElementsAttr : public DenseIntOrFPElementsAttr {
public:
  /// Iteration over this attribute yields the float element values.
  using iterator = DenseElementsAttr::FloatElementIterator;

  using DenseIntOrFPElementsAttr::DenseIntOrFPElementsAttr;

  /// Get an instance of a DenseFPElementsAttr with the given arguments. This
  /// simply wraps the DenseElementsAttr::get calls and casts the result to
  /// DenseFPElementsAttr.
  template <typename Arg>
  static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg) {
    return DenseElementsAttr::get(type, llvm::makeArrayRef(arg))
        .template cast<DenseFPElementsAttr>();
  }
  template <typename T>
  static DenseFPElementsAttr get(const ShapedType &type,
                                 const std::initializer_list<T> &list) {
    return DenseElementsAttr::get(type, list)
        .template cast<DenseFPElementsAttr>();
  }

  /// Generates a new DenseElementsAttr by mapping each value attribute, and
  /// constructing the DenseElementsAttr given the new element type.
  DenseElementsAttr
  mapValues(Type newElementType,
            function_ref<APInt(const APFloat &)> mapping) const;

  /// Iterator access to the float element values.
  iterator begin() const { return float_value_begin(); }
  iterator end() const { return float_value_end(); }

  /// Method for supporting type inquiry through isa, cast and dyn_cast.
  static bool classof(Attribute attr);
};

/// An attribute that represents a reference to a dense integer vector or tensor
/// object.
class DenseIntElementsAttr : public DenseIntOrFPElementsAttr {
public:
  /// DenseIntElementsAttr iterates on APInt, so we can use the raw element
  /// iterator directly.
  using iterator = DenseElementsAttr::IntElementIterator;

  using DenseIntOrFPElementsAttr::DenseIntOrFPElementsAttr;

  /// Get an instance of a DenseIntElementsAttr with the given arguments. This
  /// simply wraps the DenseElementsAttr::get calls and casts the result to
  /// DenseIntElementsAttr.
  template <typename Arg>
  static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg) {
    return DenseElementsAttr::get(type, llvm::makeArrayRef(arg))
        .template cast<DenseIntElementsAttr>();
  }
  template <typename T>
  static DenseIntElementsAttr get(const ShapedType &type,
                                  const std::initializer_list<T> &list) {
    return DenseElementsAttr::get(type, list)
        .template cast<DenseIntElementsAttr>();
  }

  /// Generates a new DenseElementsAttr by mapping each value attribute, and
  /// constructing the DenseElementsAttr given the new element type.
  DenseElementsAttr mapValues(Type newElementType,
                              function_ref<APInt(const APInt &)> mapping) const;

  /// Iterator access to the integer element values.
  iterator begin() const { return raw_int_begin(); }
  iterator end() const { return raw_int_end(); }

  /// Method for supporting type inquiry through isa, cast and dyn_cast.
  static bool classof(Attribute attr);
};

/// An opaque attribute that represents a reference to a vector or tensor
/// constant with opaque content. This representation is for tensor constants
/// which the compiler may not need to interpret. This attribute is always
/// associated with a particular dialect, which provides a method to convert
/// tensor representation to a non-opaque format.
class OpaqueElementsAttr
    : public Attribute::AttrBase<OpaqueElementsAttr, ElementsAttr,
                                 detail::OpaqueElementsAttributeStorage> {
public:
  using Base::Base;
  using ValueType = StringRef;

  static OpaqueElementsAttr get(Dialect *dialect, ShapedType type,
                                StringRef bytes);

  StringRef getValue() const;

  /// Return the value at the given index. The 'index' is expected to refer to a
  /// valid element.
  Attribute getValue(ArrayRef<uint64_t> index) const;

  /// Decodes the attribute value using dialect-specific decoding hook.
  /// Returns false if decoding is successful. If not, returns true and leaves
  /// 'result' argument unspecified.
  bool decode(ElementsAttr &result);

  /// Returns dialect associated with this opaque constant.
  Dialect *getDialect() const;
};

/// An attribute that represents a reference to a sparse vector or tensor
/// object.
///
/// This class uses COO (coordinate list) encoding to represent the sparse
/// elements in an element attribute. Specifically, the sparse vector/tensor
/// stores the indices and values as two separate dense elements attributes of
/// tensor type (even if the sparse attribute is of vector type, in order to
/// support empty lists). The dense elements attribute indices is a 2-D tensor
/// of 64-bit integer elements with shape [N, ndims], which specifies the
/// indices of the elements in the sparse tensor that contains nonzero values.
/// The dense elements attribute values is a 1-D tensor with shape [N], and it
/// supplies the corresponding values for the indices.
///
/// For example,
/// `sparse<tensor<3x4xi32>, [[0, 0], [1, 2]], [1, 5]>` represents tensor
/// [[1, 0, 0, 0],
///  [0, 0, 5, 0],
///  [0, 0, 0, 0]].
class SparseElementsAttr
    : public Attribute::AttrBase<SparseElementsAttr, ElementsAttr,
                                 detail::SparseElementsAttributeStorage> {
public:
  using Base::Base;

  /// Iterator over the elements of this attribute as values of type 'T': a
  /// sequence of flat indices mapped through a 'T(ptrdiff_t)' function.
  template <typename T>
  using iterator =
      llvm::mapped_iterator<llvm::detail::value_sequence_iterator<ptrdiff_t>,
                            std::function<T(ptrdiff_t)>>;

  /// 'type' must be a vector or tensor with static shape.
  static SparseElementsAttr get(ShapedType type, DenseElementsAttr indices,
                                DenseElementsAttr values);

  DenseIntElementsAttr getIndices() const;

  DenseElementsAttr getValues() const;

  /// Return the values of this attribute in the form of the given type 'T'. 'T'
  /// may be any of Attribute, APInt, APFloat, c++ integer/float types, etc.
  ///
  /// The returned range covers the flat indices [0, getNumElements()). Each
  /// index is looked up with a linear scan over the flattened sparse indices;
  /// a match yields the corresponding stored value, anything else yields the
  /// zero value for 'T'.
  template <typename T>
  llvm::iterator_range<iterator<T>> getValues() const {
    auto zeroValue = getZeroValue<T>();
    auto valueIt = getValues().getValues<T>().begin();
    const std::vector<ptrdiff_t> flatSparseIndices(getFlattenedSparseIndices());
    // TODO: Move-capture flatSparseIndices when c++14 is available.
    // The lambda captures by copy, so the mapping function owns its own copy
    // of the indices, the value iterator and the zero value.
    std::function<T(ptrdiff_t)> mapFn = [=](ptrdiff_t index) {
      // Try to map the current index to one of the sparse indices.
      for (unsigned i = 0, e = flatSparseIndices.size(); i != e; ++i)
        if (flatSparseIndices[i] == index)
          return *std::next(valueIt, i);
      // Otherwise, return the zero value.
      return zeroValue;
    };
    return llvm::map_range(llvm::seq<ptrdiff_t>(0, getNumElements()), mapFn);
  }

  /// Return the value of the element at the given index. The 'index' is
  /// expected to refer to a valid element.
  Attribute getValue(ArrayRef<uint64_t> index) const;

private:
  /// Get a zero APFloat for the given sparse attribute.
  APFloat getZeroAPFloat() const;

  /// Get a zero APInt for the given sparse attribute.
  APInt getZeroAPInt() const;

  /// Get a zero attribute for the given sparse attribute.
  Attribute getZeroAttr() const;

  /// Utility methods to generate a zero value of some type 'T'. This is used by
  /// the 'iterator' class. Exactly one overload is enabled for each supported
  /// 'T'.
  /// Get a zero for a given attribute type.
  template <typename T>
  typename std::enable_if<std::is_base_of<Attribute, T>::value, T>::type
  getZeroValue() const {
    return getZeroAttr().template cast<T>();
  }
  /// Get a zero for an APInt.
  template <typename T>
  typename std::enable_if<std::is_same<APInt, T>::value, T>::type
  getZeroValue() const {
    return getZeroAPInt();
  }
  /// Get a zero for a complex APInt: both components are the zero APInt.
  template <typename T>
  typename std::enable_if<std::is_same<std::complex<APInt>, T>::value, T>::type
  getZeroValue() const {
    APInt intZero = getZeroAPInt();
    return {intZero, intZero};
  }
  /// Get a zero for an APFloat.
  template <typename T>
  typename std::enable_if<std::is_same<APFloat, T>::value, T>::type
  getZeroValue() const {
    return getZeroAPFloat();
  }
  /// Get a zero for a complex APFloat: both components are the zero APFloat.
  template <typename T>
  typename std::enable_if<std::is_same<std::complex<APFloat>, T>::value,
                          T>::type
  getZeroValue() const {
    APFloat floatZero = getZeroAPFloat();
    return {floatZero, floatZero};
  }

  /// Get a zero for a C++ integer, float, StringRef, or complex type. The
  /// zero is the value-initialized 'T'. Complex APInt/APFloat are excluded
  /// here because they have dedicated overloads.
  template <typename T>
  typename std::enable_if<
      std::numeric_limits<T>::is_integer ||
          DenseElementsAttr::is_valid_cpp_fp_type<T>::value ||
          std::is_same<T, StringRef>::value ||
          (detail::is_complex_t<T>::value &&
           !llvm::is_one_of<T, std::complex<APInt>,
                            std::complex<APFloat>>::value),
      T>::type
  getZeroValue() const {
    return T();
  }

  /// Flatten, and return, all of the sparse indices in this attribute in
  /// row-major order.
  std::vector<ptrdiff_t> getFlattenedSparseIndices() const;
};

/// An attribute that represents a reference to a splat vector or tensor
/// constant, meaning all of the elements have the same value.
class SplatElementsAttr : public DenseElementsAttr {
public:
  using DenseElementsAttr::DenseElementsAttr;

  /// Method for supporting type inquiry through isa, cast and dyn_cast. An
  /// attribute is a SplatElementsAttr when it is a DenseElementsAttr that
  /// reports 'isSplat()'.
  static bool classof(Attribute attr) {
    auto denseAttr = attr.dyn_cast<DenseElementsAttr>();
    return denseAttr && denseAttr.isSplat();
  }
};

namespace detail {
/// This class represents a general iterator over the values of an ElementsAttr.
/// It supports all subclasses aside from OpaqueElementsAttr.
///
/// The concrete iterator is kept in a union; which member is live is decided
/// by the kind of the parent attribute (dense or sparse), see 'process'.
template <typename T>
class ElementsAttrIterator
    : public llvm::iterator_facade_base<ElementsAttrIterator<T>,
                                        std::random_access_iterator_tag, T,
                                        std::ptrdiff_t, T, T> {
  // NOTE: We use a dummy enable_if here because MSVC cannot use 'decltype'
  // inside of a conversion operator.
  using DenseIteratorT = typename std::enable_if<
      true,
      decltype(std::declval<DenseElementsAttr>().getValues<T>().begin())>::type;
  using SparseIteratorT = SparseElementsAttr::iterator<T>;

  /// A union containing the specific iterators for each derived attribute kind.
  /// The empty constructor and destructor leave member lifetime management to
  /// the enclosing iterator class.
  union Iterator {
    Iterator(DenseIteratorT &&it) : denseIt(std::move(it)) {}
    Iterator(SparseIteratorT &&it) : sparseIt(std::move(it)) {}
    Iterator() {}
    ~Iterator() {}

    // Conversions that let the functors below receive the concrete iterator
    // type directly.
    operator const DenseIteratorT &() const { return denseIt; }
    operator const SparseIteratorT &() const { return sparseIt; }
    operator DenseIteratorT &() { return denseIt; }
    operator SparseIteratorT &() { return sparseIt; }

    /// An instance of a dense elements iterator.
    DenseIteratorT denseIt;
    /// An instance of a sparse elements iterator.
    SparseIteratorT sparseIt;
  };

  /// Utility method to process a functor on each of the internal iterator
  /// types. 'ProcessFn' is instantiated with the iterator type matching the
  /// kind of 'attr' and invoked with 'args'; any other attribute kind is
  /// unreachable.
  template <typename RetT, template <typename> class ProcessFn,
            typename... Args>
  RetT process(Args &...args) const {
    if (attr.isa<DenseElementsAttr>())
      return ProcessFn<DenseIteratorT>()(args...);
    if (attr.isa<SparseElementsAttr>())
      return ProcessFn<SparseIteratorT>()(args...);
    llvm_unreachable("unexpected attribute kind");
  }

  /// Utility functors used to generically implement the iterators methods.
  template <typename ItT>
  struct PlusAssign {
    void operator()(ItT &it, ptrdiff_t offset) { it += offset; }
  };
  template <typename ItT>
  struct Minus {
    ptrdiff_t operator()(const ItT &lhs, const ItT &rhs) { return lhs - rhs; }
  };
  template <typename ItT>
  struct MinusAssign {
    void operator()(ItT &it, ptrdiff_t offset) { it -= offset; }
  };
  template <typename ItT>
  struct Dereference {
    T operator()(ItT &it) { return *it; }
  };
  /// Copy-constructs 'it' into the storage of 'dest' using placement new.
  template <typename ItT>
  struct ConstructIter {
    void operator()(ItT &dest, const ItT &it) { ::new (&dest) ItT(it); }
  };
  /// Explicitly runs the destructor of 'it'.
  template <typename ItT>
  struct DestructIter {
    void operator()(ItT &it) { it.~ItT(); }
  };

public:
  /// Copy the parent attribute, then copy-construct the union member that
  /// corresponds to its kind.
  ElementsAttrIterator(const ElementsAttrIterator<T> &rhs) : attr(rhs.attr) {
    process<void, ConstructIter>(it, rhs.it);
  }
  /// Destroy the union member that corresponds to the kind of 'attr'.
  ~ElementsAttrIterator() { process<void, DestructIter>(it); }

  /// Methods necessary to support random access iteration.
  ptrdiff_t operator-(const ElementsAttrIterator<T> &rhs) const {
    assert(attr == rhs.attr && "incompatible iterators");
    return process<ptrdiff_t, Minus>(it, rhs.it);
  }
  bool operator==(const ElementsAttrIterator<T> &rhs) const {
    return rhs.attr == attr && process<bool, std::equal_to>(it, rhs.it);
  }
  bool operator<(const ElementsAttrIterator<T> &rhs) const {
    assert(attr == rhs.attr && "incompatible iterators");
    return process<bool, std::less>(it, rhs.it);
  }
  ElementsAttrIterator<T> &operator+=(ptrdiff_t offset) {
    process<void, PlusAssign>(it, offset);
    return *this;
  }
  ElementsAttrIterator<T> &operator-=(ptrdiff_t offset) {
    process<void, MinusAssign>(it, offset);
    return *this;
  }

  /// Dereference the iterator at the current index.
  T operator*() { return process<T, Dereference>(it); }

private:
  /// Construct from the parent attribute and one of the concrete iterators;
  /// the iterator is forwarded into the union.
  template <typename IteratorT>
  ElementsAttrIterator(Attribute attr, IteratorT &&it)
      : attr(attr), it(std::forward<IteratorT>(it)) {}

  /// Allow accessing the constructor.
  friend ElementsAttr;

  /// The parent elements attribute.
  Attribute attr;

  /// A union containing the specific iterators for each derived kind.
  Iterator it;
};

/// A range of ElementsAttrIterator<T>; it inherits the constructors of
/// llvm::iterator_range.
template <typename T>
class ElementsAttrRange : public llvm::iterator_range<ElementsAttrIterator<T>> {
  using llvm::iterator_range<ElementsAttrIterator<T>>::iterator_range;
};
} // namespace detail

/// Return the elements of this attribute as a range of values of type 'T'.
1460 template <typename T> 1461 auto ElementsAttr::getValues() const -> iterator_range<T> { 1462 if (DenseElementsAttr denseAttr = dyn_cast<DenseElementsAttr>()) { 1463 auto values = denseAttr.getValues<T>(); 1464 return {iterator<T>(*this, values.begin()), 1465 iterator<T>(*this, values.end())}; 1466 } 1467 if (SparseElementsAttr sparseAttr = dyn_cast<SparseElementsAttr>()) { 1468 auto values = sparseAttr.getValues<T>(); 1469 return {iterator<T>(*this, values.begin()), 1470 iterator<T>(*this, values.end())}; 1471 } 1472 llvm_unreachable("unexpected attribute kind"); 1473 } 1474 1475 //===----------------------------------------------------------------------===// 1476 // MutableDictionaryAttr 1477 //===----------------------------------------------------------------------===// 1478 1479 /// A MutableDictionaryAttr is a mutable wrapper around a DictionaryAttr. It 1480 /// provides additional interfaces for adding, removing, replacing attributes 1481 /// within a DictionaryAttr. 1482 /// 1483 /// We assume there will be relatively few attributes on a given operation 1484 /// (maybe a dozen or so, but not hundreds or thousands) so we use linear 1485 /// searches for everything. 1486 class MutableDictionaryAttr { 1487 public: 1488 MutableDictionaryAttr(DictionaryAttr attrs = nullptr) 1489 : attrs((attrs && !attrs.empty()) ? attrs : nullptr) {} 1490 MutableDictionaryAttr(ArrayRef<NamedAttribute> attributes); 1491 1492 bool operator!=(const MutableDictionaryAttr &other) const { 1493 return !(*this == other); 1494 } 1495 bool operator==(const MutableDictionaryAttr &other) const { 1496 return attrs == other.attrs; 1497 } 1498 1499 /// Return the underlying dictionary attribute. 1500 DictionaryAttr getDictionary(MLIRContext *context) const; 1501 1502 /// Return the underlying dictionary attribute or null if there are no 1503 /// attributes within this dictionary. 
1504 DictionaryAttr getDictionaryOrNull() const { return attrs; } 1505 1506 /// Return all of the attributes on this operation. 1507 ArrayRef<NamedAttribute> getAttrs() const; 1508 1509 /// Replace the held attributes with ones provided in 'newAttrs'. 1510 void setAttrs(ArrayRef<NamedAttribute> attributes); 1511 1512 /// Return the specified attribute if present, null otherwise. 1513 Attribute get(StringRef name) const; 1514 Attribute get(Identifier name) const; 1515 1516 /// Return the specified named attribute if present, None otherwise. 1517 Optional<NamedAttribute> getNamed(StringRef name) const; 1518 Optional<NamedAttribute> getNamed(Identifier name) const; 1519 1520 /// If the an attribute exists with the specified name, change it to the new 1521 /// value. Otherwise, add a new attribute with the specified name/value. 1522 void set(Identifier name, Attribute value); 1523 1524 enum class RemoveResult { Removed, NotFound }; 1525 1526 /// Remove the attribute with the specified name if it exists. The return 1527 /// value indicates whether the attribute was present or not. 1528 RemoveResult remove(Identifier name); 1529 1530 bool empty() const { return attrs == nullptr; } 1531 1532 private: 1533 friend ::llvm::hash_code hash_value(const MutableDictionaryAttr &arg); 1534 1535 DictionaryAttr attrs; 1536 }; 1537 1538 inline ::llvm::hash_code hash_value(const MutableDictionaryAttr &arg) { 1539 if (!arg.attrs) 1540 return ::llvm::hash_value((void *)nullptr); 1541 return hash_value(arg.attrs); 1542 } 1543 1544 } // end namespace mlir. 1545 1546 namespace llvm { 1547 1548 template <> 1549 struct PointerLikeTypeTraits<mlir::SymbolRefAttr> 1550 : public PointerLikeTypeTraits<mlir::Attribute> { 1551 static inline mlir::SymbolRefAttr getFromVoidPointer(void *ptr) { 1552 return PointerLikeTypeTraits<mlir::Attribute>::getFromVoidPointer(ptr) 1553 .cast<mlir::SymbolRefAttr>(); 1554 } 1555 }; 1556 1557 } // namespace llvm 1558 1559 #endif // MLIR_IR_BUILTINATTRIBUTES_H 1560