//===- AttributeDetail.h - MLIR Attribute details class ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This holds implementation details of Attribute.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #ifndef ATTRIBUTEDETAIL_H_
14 #define ATTRIBUTEDETAIL_H_
15
16 #include "mlir/IR/AffineMap.h"
17 #include "mlir/IR/BuiltinAttributes.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Identifier.h"
20 #include "mlir/IR/IntegerSet.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/Support/StorageUniquer.h"
23 #include "llvm/ADT/APFloat.h"
24 #include "llvm/ADT/PointerIntPair.h"
25 #include "llvm/Support/TrailingObjects.h"
26
27 namespace mlir {
28 namespace detail {
29 // An attribute representing a reference to an affine map.
30 struct AffineMapAttributeStorage : public AttributeStorage {
31 using KeyTy = AffineMap;
32
AffineMapAttributeStorageAffineMapAttributeStorage33 AffineMapAttributeStorage(AffineMap value)
34 : AttributeStorage(IndexType::get(value.getContext())), value(value) {}
35
36 /// Key equality function.
37 bool operator==(const KeyTy &key) const { return key == value; }
38
39 /// Construct a new storage instance.
40 static AffineMapAttributeStorage *
constructAffineMapAttributeStorage41 construct(AttributeStorageAllocator &allocator, KeyTy key) {
42 return new (allocator.allocate<AffineMapAttributeStorage>())
43 AffineMapAttributeStorage(key);
44 }
45
46 AffineMap value;
47 };
48
49 /// An attribute representing an array of other attributes.
50 struct ArrayAttributeStorage : public AttributeStorage {
51 using KeyTy = ArrayRef<Attribute>;
52
ArrayAttributeStorageArrayAttributeStorage53 ArrayAttributeStorage(ArrayRef<Attribute> value) : value(value) {}
54
55 /// Key equality function.
56 bool operator==(const KeyTy &key) const { return key == value; }
57
58 /// Construct a new storage instance.
constructArrayAttributeStorage59 static ArrayAttributeStorage *construct(AttributeStorageAllocator &allocator,
60 const KeyTy &key) {
61 return new (allocator.allocate<ArrayAttributeStorage>())
62 ArrayAttributeStorage(allocator.copyInto(key));
63 }
64
65 ArrayRef<Attribute> value;
66 };
67
68 /// An attribute representing a dictionary of sorted named attributes.
69 struct DictionaryAttributeStorage final
70 : public AttributeStorage,
71 private llvm::TrailingObjects<DictionaryAttributeStorage,
72 NamedAttribute> {
73 using KeyTy = ArrayRef<NamedAttribute>;
74
75 /// Given a list of NamedAttribute's, canonicalize the list (sorting
76 /// by name) and return the unique'd result.
77 static DictionaryAttributeStorage *get(ArrayRef<NamedAttribute> attrs);
78
79 /// Key equality function.
80 bool operator==(const KeyTy &key) const { return key == getElements(); }
81
82 /// Construct a new storage instance.
83 static DictionaryAttributeStorage *
constructfinal84 construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
85 auto size = DictionaryAttributeStorage::totalSizeToAlloc<NamedAttribute>(
86 key.size());
87 auto rawMem = allocator.allocate(size, alignof(DictionaryAttributeStorage));
88
89 // Initialize the storage and trailing attribute list.
90 auto result = ::new (rawMem) DictionaryAttributeStorage(key.size());
91 std::uninitialized_copy(key.begin(), key.end(),
92 result->getTrailingObjects<NamedAttribute>());
93 return result;
94 }
95
96 /// Return the elements of this dictionary attribute.
getElementsfinal97 ArrayRef<NamedAttribute> getElements() const {
98 return {getTrailingObjects<NamedAttribute>(), numElements};
99 }
100
101 private:
102 friend class llvm::TrailingObjects<DictionaryAttributeStorage,
103 NamedAttribute>;
104
105 // This is used by the llvm::TrailingObjects base class.
numTrailingObjectsfinal106 size_t numTrailingObjects(OverloadToken<NamedAttribute>) const {
107 return numElements;
108 }
DictionaryAttributeStoragefinal109 DictionaryAttributeStorage(unsigned numElements) : numElements(numElements) {}
110
111 /// This is the number of attributes.
112 const unsigned numElements;
113 };
114
115 /// An attribute representing a floating point value.
116 struct FloatAttributeStorage final
117 : public AttributeStorage,
118 public llvm::TrailingObjects<FloatAttributeStorage, uint64_t> {
119 using KeyTy = std::pair<Type, APFloat>;
120
FloatAttributeStoragefinal121 FloatAttributeStorage(const llvm::fltSemantics &semantics, Type type,
122 size_t numObjects)
123 : AttributeStorage(type), semantics(semantics), numObjects(numObjects) {}
124
125 /// Key equality and hash functions.
126 bool operator==(const KeyTy &key) const {
127 return key.first == getType() && key.second.bitwiseIsEqual(getValue());
128 }
hashKeyfinal129 static unsigned hashKey(const KeyTy &key) {
130 return llvm::hash_combine(key.first, llvm::hash_value(key.second));
131 }
132
133 /// Construct a key with a type and double.
getKeyfinal134 static KeyTy getKey(Type type, double value) {
135 if (type.isF64())
136 return KeyTy(type, APFloat(value));
137
138 // This handles, e.g., F16 because there is no APFloat constructor for it.
139 bool unused;
140 APFloat val(value);
141 val.convert(type.cast<FloatType>().getFloatSemantics(),
142 APFloat::rmNearestTiesToEven, &unused);
143 return KeyTy(type, val);
144 }
145
146 /// Construct a new storage instance.
constructfinal147 static FloatAttributeStorage *construct(AttributeStorageAllocator &allocator,
148 const KeyTy &key) {
149 const auto &apint = key.second.bitcastToAPInt();
150
151 // Here one word's bitwidth equals to that of uint64_t.
152 auto elements = ArrayRef<uint64_t>(apint.getRawData(), apint.getNumWords());
153
154 auto byteSize =
155 FloatAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size());
156 auto rawMem = allocator.allocate(byteSize, alignof(FloatAttributeStorage));
157 auto result = ::new (rawMem) FloatAttributeStorage(
158 key.second.getSemantics(), key.first, elements.size());
159 std::uninitialized_copy(elements.begin(), elements.end(),
160 result->getTrailingObjects<uint64_t>());
161 return result;
162 }
163
164 /// Returns an APFloat representing the stored value.
getValuefinal165 APFloat getValue() const {
166 auto val = APInt(APFloat::getSizeInBits(semantics),
167 {getTrailingObjects<uint64_t>(), numObjects});
168 return APFloat(semantics, val);
169 }
170
171 const llvm::fltSemantics &semantics;
172 size_t numObjects;
173 };
174
175 /// An attribute representing an integral value.
176 struct IntegerAttributeStorage final
177 : public AttributeStorage,
178 public llvm::TrailingObjects<IntegerAttributeStorage, uint64_t> {
179 using KeyTy = std::pair<Type, APInt>;
180
IntegerAttributeStoragefinal181 IntegerAttributeStorage(Type type, size_t numObjects)
182 : AttributeStorage(type), numObjects(numObjects) {
183 assert((type.isIndex() || type.isa<IntegerType>()) && "invalid type");
184 }
185
186 /// Key equality and hash functions.
187 bool operator==(const KeyTy &key) const {
188 return key == KeyTy(getType(), getValue());
189 }
hashKeyfinal190 static unsigned hashKey(const KeyTy &key) {
191 return llvm::hash_combine(key.first, llvm::hash_value(key.second));
192 }
193
194 /// Construct a new storage instance.
195 static IntegerAttributeStorage *
constructfinal196 construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
197 Type type;
198 APInt value;
199 std::tie(type, value) = key;
200
201 auto elements = ArrayRef<uint64_t>(value.getRawData(), value.getNumWords());
202 auto size =
203 IntegerAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size());
204 auto rawMem = allocator.allocate(size, alignof(IntegerAttributeStorage));
205 auto result = ::new (rawMem) IntegerAttributeStorage(type, elements.size());
206 std::uninitialized_copy(elements.begin(), elements.end(),
207 result->getTrailingObjects<uint64_t>());
208 return result;
209 }
210
211 /// Returns an APInt representing the stored value.
getValuefinal212 APInt getValue() const {
213 if (getType().isIndex())
214 return APInt(64, {getTrailingObjects<uint64_t>(), numObjects});
215 return APInt(getType().getIntOrFloatBitWidth(),
216 {getTrailingObjects<uint64_t>(), numObjects});
217 }
218
219 size_t numObjects;
220 };
221
222 // An attribute representing a reference to an integer set.
223 struct IntegerSetAttributeStorage : public AttributeStorage {
224 using KeyTy = IntegerSet;
225
IntegerSetAttributeStorageIntegerSetAttributeStorage226 IntegerSetAttributeStorage(IntegerSet value) : value(value) {}
227
228 /// Key equality function.
229 bool operator==(const KeyTy &key) const { return key == value; }
230
231 /// Construct a new storage instance.
232 static IntegerSetAttributeStorage *
constructIntegerSetAttributeStorage233 construct(AttributeStorageAllocator &allocator, KeyTy key) {
234 return new (allocator.allocate<IntegerSetAttributeStorage>())
235 IntegerSetAttributeStorage(key);
236 }
237
238 IntegerSet value;
239 };
240
241 /// Opaque Attribute Storage and Uniquing.
242 struct OpaqueAttributeStorage : public AttributeStorage {
OpaqueAttributeStorageOpaqueAttributeStorage243 OpaqueAttributeStorage(Identifier dialectNamespace, StringRef attrData,
244 Type type)
245 : AttributeStorage(type), dialectNamespace(dialectNamespace),
246 attrData(attrData) {}
247
248 /// The hash key used for uniquing.
249 using KeyTy = std::tuple<Identifier, StringRef, Type>;
250 bool operator==(const KeyTy &key) const {
251 return key == KeyTy(dialectNamespace, attrData, getType());
252 }
253
constructOpaqueAttributeStorage254 static OpaqueAttributeStorage *construct(AttributeStorageAllocator &allocator,
255 const KeyTy &key) {
256 return new (allocator.allocate<OpaqueAttributeStorage>())
257 OpaqueAttributeStorage(std::get<0>(key),
258 allocator.copyInto(std::get<1>(key)),
259 std::get<2>(key));
260 }
261
262 // The dialect namespace.
263 Identifier dialectNamespace;
264
265 // The parser attribute data for this opaque attribute.
266 StringRef attrData;
267 };
268
269 /// An attribute representing a string value.
270 struct StringAttributeStorage : public AttributeStorage {
271 using KeyTy = std::pair<StringRef, Type>;
272
StringAttributeStorageStringAttributeStorage273 StringAttributeStorage(StringRef value, Type type)
274 : AttributeStorage(type), value(value) {}
275
276 /// Key equality function.
277 bool operator==(const KeyTy &key) const {
278 return key == KeyTy(value, getType());
279 }
280
281 /// Construct a new storage instance.
constructStringAttributeStorage282 static StringAttributeStorage *construct(AttributeStorageAllocator &allocator,
283 const KeyTy &key) {
284 return new (allocator.allocate<StringAttributeStorage>())
285 StringAttributeStorage(allocator.copyInto(key.first), key.second);
286 }
287
288 StringRef value;
289 };
290
291 /// An attribute representing a symbol reference.
292 struct SymbolRefAttributeStorage final
293 : public AttributeStorage,
294 public llvm::TrailingObjects<SymbolRefAttributeStorage,
295 FlatSymbolRefAttr> {
296 using KeyTy = std::pair<StringRef, ArrayRef<FlatSymbolRefAttr>>;
297
SymbolRefAttributeStoragefinal298 SymbolRefAttributeStorage(StringRef value, size_t numNestedRefs)
299 : value(value), numNestedRefs(numNestedRefs) {}
300
301 /// Key equality function.
302 bool operator==(const KeyTy &key) const {
303 return key == KeyTy(value, getNestedRefs());
304 }
305
306 /// Construct a new storage instance.
307 static SymbolRefAttributeStorage *
constructfinal308 construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
309 auto size = SymbolRefAttributeStorage::totalSizeToAlloc<FlatSymbolRefAttr>(
310 key.second.size());
311 auto rawMem = allocator.allocate(size, alignof(SymbolRefAttributeStorage));
312 auto result = ::new (rawMem) SymbolRefAttributeStorage(
313 allocator.copyInto(key.first), key.second.size());
314 std::uninitialized_copy(key.second.begin(), key.second.end(),
315 result->getTrailingObjects<FlatSymbolRefAttr>());
316 return result;
317 }
318
319 /// Returns the set of nested references.
getNestedRefsfinal320 ArrayRef<FlatSymbolRefAttr> getNestedRefs() const {
321 return {getTrailingObjects<FlatSymbolRefAttr>(), numNestedRefs};
322 }
323
324 StringRef value;
325 size_t numNestedRefs;
326 };
327
328 /// An attribute representing a reference to a type.
329 struct TypeAttributeStorage : public AttributeStorage {
330 using KeyTy = Type;
331
TypeAttributeStorageTypeAttributeStorage332 TypeAttributeStorage(Type value) : value(value) {}
333
334 /// Key equality function.
335 bool operator==(const KeyTy &key) const { return key == value; }
336
337 /// Construct a new storage instance.
constructTypeAttributeStorage338 static TypeAttributeStorage *construct(AttributeStorageAllocator &allocator,
339 KeyTy key) {
340 return new (allocator.allocate<TypeAttributeStorage>())
341 TypeAttributeStorage(key);
342 }
343
344 Type value;
345 };
346
347 //===----------------------------------------------------------------------===//
348 // Elements Attributes
349 //===----------------------------------------------------------------------===//
350
351 /// Return the bit width which DenseElementsAttr should use for this type.
getDenseElementBitWidth(Type eltType)352 inline size_t getDenseElementBitWidth(Type eltType) {
353 // Align the width for complex to 8 to make storage and interpretation easier.
354 if (ComplexType comp = eltType.dyn_cast<ComplexType>())
355 return llvm::alignTo<8>(getDenseElementBitWidth(comp.getElementType())) * 2;
356 if (eltType.isIndex())
357 return IndexType::kInternalStorageBitWidth;
358 return eltType.getIntOrFloatBitWidth();
359 }
360
361 /// An attribute representing a reference to a dense vector or tensor object.
362 struct DenseElementsAttributeStorage : public AttributeStorage {
363 public:
DenseElementsAttributeStorageDenseElementsAttributeStorage364 DenseElementsAttributeStorage(ShapedType ty, bool isSplat)
365 : AttributeStorage(ty), isSplat(isSplat) {}
366
367 bool isSplat;
368 };
369
370 /// An attribute representing a reference to a dense vector or tensor object.
371 struct DenseIntOrFPElementsAttributeStorage
372 : public DenseElementsAttributeStorage {
373 DenseIntOrFPElementsAttributeStorage(ShapedType ty, ArrayRef<char> data,
374 bool isSplat = false)
DenseElementsAttributeStorageDenseIntOrFPElementsAttributeStorage375 : DenseElementsAttributeStorage(ty, isSplat), data(data) {}
376
377 struct KeyTy {
378 KeyTy(ShapedType type, ArrayRef<char> data, llvm::hash_code hashCode,
379 bool isSplat = false)
typeDenseIntOrFPElementsAttributeStorage::KeyTy380 : type(type), data(data), hashCode(hashCode), isSplat(isSplat) {}
381
382 /// The type of the dense elements.
383 ShapedType type;
384
385 /// The raw buffer for the data storage.
386 ArrayRef<char> data;
387
388 /// The computed hash code for the storage data.
389 llvm::hash_code hashCode;
390
391 /// A boolean that indicates if this data is a splat or not.
392 bool isSplat;
393 };
394
395 /// Compare this storage instance with the provided key.
396 bool operator==(const KeyTy &key) const {
397 if (key.type != getType())
398 return false;
399
400 // For boolean splats we need to explicitly check that the first bit is the
401 // same. Boolean values are packed at the bit level, and even though a splat
402 // is detected the rest of the bits in the first byte may differ from the
403 // splat value.
404 if (key.type.getElementType().isInteger(1)) {
405 if (key.isSplat != isSplat)
406 return false;
407 if (isSplat)
408 return (key.data.front() & 1) == data.front();
409 }
410
411 // Otherwise, we can default to just checking the data.
412 return key.data == data;
413 }
414
415 /// Construct a key from a shaped type, raw data buffer, and a flag that
416 /// signals if the data is already known to be a splat. Callers to this
417 /// function are expected to tag preknown splat values when possible, e.g. one
418 /// element shapes.
getKeyDenseIntOrFPElementsAttributeStorage419 static KeyTy getKey(ShapedType ty, ArrayRef<char> data, bool isKnownSplat) {
420 // Handle an empty storage instance.
421 if (data.empty())
422 return KeyTy(ty, data, 0);
423
424 // If the data is already known to be a splat, the key hash value is
425 // directly the data buffer.
426 if (isKnownSplat)
427 return KeyTy(ty, data, llvm::hash_value(data), isKnownSplat);
428
429 // Otherwise, we need to check if the data corresponds to a splat or not.
430
431 // Handle the simple case of only one element.
432 size_t numElements = ty.getNumElements();
433 assert(numElements != 1 && "splat of 1 element should already be detected");
434
435 // Handle boolean values directly as they are packed to 1-bit.
436 if (ty.getElementType().isInteger(1) == 1)
437 return getKeyForBoolData(ty, data, numElements);
438
439 size_t elementWidth = getDenseElementBitWidth(ty.getElementType());
440 // Non 1-bit dense elements are padded to 8-bits.
441 size_t storageSize = llvm::divideCeil(elementWidth, CHAR_BIT);
442 assert(((data.size() / storageSize) == numElements) &&
443 "data does not hold expected number of elements");
444
445 // Create the initial hash value with just the first element.
446 auto firstElt = data.take_front(storageSize);
447 auto hashVal = llvm::hash_value(firstElt);
448
449 // Check to see if this storage represents a splat. If it doesn't then
450 // combine the hash for the data starting with the first non splat element.
451 for (size_t i = storageSize, e = data.size(); i != e; i += storageSize)
452 if (memcmp(data.data(), &data[i], storageSize))
453 return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i)));
454
455 // Otherwise, this is a splat so just return the hash of the first element.
456 return KeyTy(ty, firstElt, hashVal, /*isSplat=*/true);
457 }
458
459 /// Construct a key with a set of boolean data.
getKeyForBoolDataDenseIntOrFPElementsAttributeStorage460 static KeyTy getKeyForBoolData(ShapedType ty, ArrayRef<char> data,
461 size_t numElements) {
462 ArrayRef<char> splatData = data;
463 bool splatValue = splatData.front() & 1;
464
465 // Helper functor to generate a KeyTy for a boolean splat value.
466 auto generateSplatKey = [=] {
467 return KeyTy(ty, data.take_front(1),
468 llvm::hash_value(ArrayRef<char>(splatValue ? 1 : 0)),
469 /*isSplat=*/true);
470 };
471
472 // Handle the case where the potential splat value is 1 and the number of
473 // elements is non 8-bit aligned.
474 size_t numOddElements = numElements % CHAR_BIT;
475 if (splatValue && numOddElements != 0) {
476 // Check that all bits are set in the last value.
477 char lastElt = splatData.back();
478 if (lastElt != llvm::maskTrailingOnes<unsigned char>(numOddElements))
479 return KeyTy(ty, data, llvm::hash_value(data));
480
481 // If this is the only element, the data is known to be a splat.
482 if (splatData.size() == 1)
483 return generateSplatKey();
484 splatData = splatData.drop_back();
485 }
486
487 // Check that the data buffer corresponds to a splat of the proper mask.
488 char mask = splatValue ? ~0 : 0;
489 return llvm::all_of(splatData, [mask](char c) { return c == mask; })
490 ? generateSplatKey()
491 : KeyTy(ty, data, llvm::hash_value(data));
492 }
493
494 /// Hash the key for the storage.
hashKeyDenseIntOrFPElementsAttributeStorage495 static llvm::hash_code hashKey(const KeyTy &key) {
496 return llvm::hash_combine(key.type, key.hashCode);
497 }
498
499 /// Construct a new storage instance.
500 static DenseIntOrFPElementsAttributeStorage *
constructDenseIntOrFPElementsAttributeStorage501 construct(AttributeStorageAllocator &allocator, KeyTy key) {
502 // If the data buffer is non-empty, we copy it into the allocator with a
503 // 64-bit alignment.
504 ArrayRef<char> copy, data = key.data;
505 if (!data.empty()) {
506 char *rawData = reinterpret_cast<char *>(
507 allocator.allocate(data.size(), alignof(uint64_t)));
508 std::memcpy(rawData, data.data(), data.size());
509
510 // If this is a boolean splat, make sure only the first bit is used.
511 if (key.isSplat && key.type.getElementType().isInteger(1))
512 rawData[0] &= 1;
513 copy = ArrayRef<char>(rawData, data.size());
514 }
515
516 return new (allocator.allocate<DenseIntOrFPElementsAttributeStorage>())
517 DenseIntOrFPElementsAttributeStorage(key.type, copy, key.isSplat);
518 }
519
520 ArrayRef<char> data;
521 };
522
523 /// An attribute representing a reference to a dense vector or tensor object
524 /// containing strings.
525 struct DenseStringElementsAttributeStorage
526 : public DenseElementsAttributeStorage {
527 DenseStringElementsAttributeStorage(ShapedType ty, ArrayRef<StringRef> data,
528 bool isSplat = false)
DenseElementsAttributeStorageDenseStringElementsAttributeStorage529 : DenseElementsAttributeStorage(ty, isSplat), data(data) {}
530
531 struct KeyTy {
532 KeyTy(ShapedType type, ArrayRef<StringRef> data, llvm::hash_code hashCode,
533 bool isSplat = false)
typeDenseStringElementsAttributeStorage::KeyTy534 : type(type), data(data), hashCode(hashCode), isSplat(isSplat) {}
535
536 /// The type of the dense elements.
537 ShapedType type;
538
539 /// The raw buffer for the data storage.
540 ArrayRef<StringRef> data;
541
542 /// The computed hash code for the storage data.
543 llvm::hash_code hashCode;
544
545 /// A boolean that indicates if this data is a splat or not.
546 bool isSplat;
547 };
548
549 /// Compare this storage instance with the provided key.
550 bool operator==(const KeyTy &key) const {
551 if (key.type != getType())
552 return false;
553
554 // Otherwise, we can default to just checking the data. StringRefs compare
555 // by contents.
556 return key.data == data;
557 }
558
559 /// Construct a key from a shaped type, StringRef data buffer, and a flag that
560 /// signals if the data is already known to be a splat. Callers to this
561 /// function are expected to tag preknown splat values when possible, e.g. one
562 /// element shapes.
getKeyDenseStringElementsAttributeStorage563 static KeyTy getKey(ShapedType ty, ArrayRef<StringRef> data,
564 bool isKnownSplat) {
565 // Handle an empty storage instance.
566 if (data.empty())
567 return KeyTy(ty, data, 0);
568
569 // If the data is already known to be a splat, the key hash value is
570 // directly the data buffer.
571 if (isKnownSplat)
572 return KeyTy(ty, data, llvm::hash_value(data.front()), isKnownSplat);
573
574 // Handle the simple case of only one element.
575 assert(ty.getNumElements() != 1 &&
576 "splat of 1 element should already be detected");
577
578 // Create the initial hash value with just the first element.
579 const auto &firstElt = data.front();
580 auto hashVal = llvm::hash_value(firstElt);
581
582 // Check to see if this storage represents a splat. If it doesn't then
583 // combine the hash for the data starting with the first non splat element.
584 for (size_t i = 1, e = data.size(); i != e; i++)
585 if (!firstElt.equals(data[i]))
586 return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i)));
587
588 // Otherwise, this is a splat so just return the hash of the first element.
589 return KeyTy(ty, data.take_front(), hashVal, /*isSplat=*/true);
590 }
591
592 /// Hash the key for the storage.
hashKeyDenseStringElementsAttributeStorage593 static llvm::hash_code hashKey(const KeyTy &key) {
594 return llvm::hash_combine(key.type, key.hashCode);
595 }
596
597 /// Construct a new storage instance.
598 static DenseStringElementsAttributeStorage *
constructDenseStringElementsAttributeStorage599 construct(AttributeStorageAllocator &allocator, KeyTy key) {
600 // If the data buffer is non-empty, we copy it into the allocator with a
601 // 64-bit alignment.
602 ArrayRef<StringRef> copy, data = key.data;
603 if (data.empty()) {
604 return new (allocator.allocate<DenseStringElementsAttributeStorage>())
605 DenseStringElementsAttributeStorage(key.type, copy, key.isSplat);
606 }
607
608 int numEntries = key.isSplat ? 1 : data.size();
609
610 // Compute the amount data needed to store the ArrayRef and StringRef
611 // contents.
612 size_t dataSize = sizeof(StringRef) * numEntries;
613 for (int i = 0; i < numEntries; i++)
614 dataSize += data[i].size();
615
616 char *rawData = reinterpret_cast<char *>(
617 allocator.allocate(dataSize, alignof(uint64_t)));
618
619 // Setup a mutable array ref of our string refs so that we can update their
620 // contents.
621 auto mutableCopy = MutableArrayRef<StringRef>(
622 reinterpret_cast<StringRef *>(rawData), numEntries);
623 auto stringData = rawData + numEntries * sizeof(StringRef);
624
625 for (int i = 0; i < numEntries; i++) {
626 memcpy(stringData, data[i].data(), data[i].size());
627 mutableCopy[i] = StringRef(stringData, data[i].size());
628 stringData += data[i].size();
629 }
630
631 copy =
632 ArrayRef<StringRef>(reinterpret_cast<StringRef *>(rawData), numEntries);
633
634 return new (allocator.allocate<DenseStringElementsAttributeStorage>())
635 DenseStringElementsAttributeStorage(key.type, copy, key.isSplat);
636 }
637
638 ArrayRef<StringRef> data;
639 };
640
641 /// An attribute representing a reference to a tensor constant with opaque
642 /// content.
643 struct OpaqueElementsAttributeStorage : public AttributeStorage {
644 using KeyTy = std::tuple<Type, Dialect *, StringRef>;
645
OpaqueElementsAttributeStorageOpaqueElementsAttributeStorage646 OpaqueElementsAttributeStorage(Type type, Dialect *dialect, StringRef bytes)
647 : AttributeStorage(type), dialect(dialect), bytes(bytes) {}
648
649 /// Key equality and hash functions.
650 bool operator==(const KeyTy &key) const {
651 return key == std::make_tuple(getType(), dialect, bytes);
652 }
hashKeyOpaqueElementsAttributeStorage653 static unsigned hashKey(const KeyTy &key) {
654 return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
655 std::get<2>(key));
656 }
657
658 /// Construct a new storage instance.
659 static OpaqueElementsAttributeStorage *
constructOpaqueElementsAttributeStorage660 construct(AttributeStorageAllocator &allocator, KeyTy key) {
661 // TODO: Provide a way to avoid copying content of large opaque
662 // tensors This will likely require a new reference attribute kind.
663 return new (allocator.allocate<OpaqueElementsAttributeStorage>())
664 OpaqueElementsAttributeStorage(std::get<0>(key), std::get<1>(key),
665 allocator.copyInto(std::get<2>(key)));
666 }
667
668 Dialect *dialect;
669 StringRef bytes;
670 };
671
672 /// An attribute representing a reference to a sparse vector or tensor object.
673 struct SparseElementsAttributeStorage : public AttributeStorage {
674 using KeyTy = std::tuple<Type, DenseIntElementsAttr, DenseElementsAttr>;
675
SparseElementsAttributeStorageSparseElementsAttributeStorage676 SparseElementsAttributeStorage(Type type, DenseIntElementsAttr indices,
677 DenseElementsAttr values)
678 : AttributeStorage(type), indices(indices), values(values) {}
679
680 /// Key equality and hash functions.
681 bool operator==(const KeyTy &key) const {
682 return key == std::make_tuple(getType(), indices, values);
683 }
hashKeySparseElementsAttributeStorage684 static unsigned hashKey(const KeyTy &key) {
685 return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
686 std::get<2>(key));
687 }
688
689 /// Construct a new storage instance.
690 static SparseElementsAttributeStorage *
constructSparseElementsAttributeStorage691 construct(AttributeStorageAllocator &allocator, KeyTy key) {
692 return new (allocator.allocate<SparseElementsAttributeStorage>())
693 SparseElementsAttributeStorage(std::get<0>(key), std::get<1>(key),
694 std::get<2>(key));
695 }
696
697 DenseIntElementsAttr indices;
698 DenseElementsAttr values;
699 };
700 } // namespace detail
701 } // namespace mlir
702
703 #endif // ATTRIBUTEDETAIL_H_
704