1 //===- BuiltinTypes.h - MLIR Builtin Type Classes ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #ifndef MLIR_IR_BUILTINTYPES_H
10 #define MLIR_IR_BUILTINTYPES_H
11
12 #include "mlir/IR/Types.h"
13
14 namespace llvm {
15 struct fltSemantics;
16 } // namespace llvm
17
18 namespace mlir {
19 class AffineExpr;
20 class AffineMap;
21 class FloatType;
22 class Identifier;
23 class IndexType;
24 class IntegerType;
25 class Location;
26 class MLIRContext;
27 class TypeRange;
28
29 namespace detail {
30
31 struct BaseMemRefTypeStorage;
32 struct ComplexTypeStorage;
33 struct FunctionTypeStorage;
34 struct IntegerTypeStorage;
35 struct MemRefTypeStorage;
36 struct OpaqueTypeStorage;
37 struct RankedTensorTypeStorage;
38 struct ShapedTypeStorage;
39 struct TupleTypeStorage;
40 struct UnrankedMemRefTypeStorage;
41 struct UnrankedTensorTypeStorage;
42 struct VectorTypeStorage;
43
44 } // namespace detail
45
46 //===----------------------------------------------------------------------===//
47 // ComplexType
48 //===----------------------------------------------------------------------===//
49
50 /// The 'complex' type represents a complex number with a parameterized element
51 /// type, which is composed of a real and imaginary value of that element type.
52 ///
53 /// The element must be a floating point or integer scalar type.
54 ///
55 class ComplexType
56 : public Type::TypeBase<ComplexType, Type, detail::ComplexTypeStorage> {
57 public:
58 using Base::Base;
59
60 /// Get or create a ComplexType with the provided element type.
61 static ComplexType get(Type elementType);
62
63 /// Get or create a ComplexType with the provided element type. This emits
64 /// and error at the specified location and returns null if the element type
65 /// isn't supported.
66 static ComplexType getChecked(Type elementType, Location location);
67
68 /// Verify the construction of an integer type.
69 static LogicalResult verifyConstructionInvariants(Location loc,
70 Type elementType);
71
72 Type getElementType();
73 };
74
75 //===----------------------------------------------------------------------===//
76 // IndexType
77 //===----------------------------------------------------------------------===//
78
79 /// Index is a special integer-like type with unknown platform-dependent bit
80 /// width.
81 class IndexType : public Type::TypeBase<IndexType, Type, TypeStorage> {
82 public:
83 using Base::Base;
84
85 /// Get an instance of the IndexType.
86 static IndexType get(MLIRContext *context);
87
88 /// Storage bit width used for IndexType by internal compiler data structures.
89 static constexpr unsigned kInternalStorageBitWidth = 64;
90 };
91
92 //===----------------------------------------------------------------------===//
93 // IntegerType
94 //===----------------------------------------------------------------------===//
95
96 /// Integer types can have arbitrary bitwidth up to a large fixed limit.
97 class IntegerType
98 : public Type::TypeBase<IntegerType, Type, detail::IntegerTypeStorage> {
99 public:
100 using Base::Base;
101
102 /// Signedness semantics.
103 enum SignednessSemantics : uint32_t {
104 Signless, /// No signedness semantics
105 Signed, /// Signed integer
106 Unsigned, /// Unsigned integer
107 };
108
109 /// Get or create a new IntegerType of the given width within the context.
110 /// The created IntegerType is signless (i.e., no signedness semantics).
111 /// Assume the width is within the allowed range and assert on failures. Use
112 /// getChecked to handle failures gracefully.
113 static IntegerType get(unsigned width, MLIRContext *context);
114
115 /// Get or create a new IntegerType of the given width within the context.
116 /// The created IntegerType has signedness semantics as indicated via
117 /// `signedness`. Assume the width is within the allowed range and assert on
118 /// failures. Use getChecked to handle failures gracefully.
119 static IntegerType get(unsigned width, SignednessSemantics signedness,
120 MLIRContext *context);
121
122 /// Get or create a new IntegerType of the given width within the context,
123 /// defined at the given, potentially unknown, location. The created
124 /// IntegerType is signless (i.e., no signedness semantics). If the width is
125 /// outside the allowed range, emit errors and return a null type.
126 static IntegerType getChecked(unsigned width, Location location);
127
128 /// Get or create a new IntegerType of the given width within the context,
129 /// defined at the given, potentially unknown, location. The created
130 /// IntegerType has signedness semantics as indicated via `signedness`. If the
131 /// width is outside the allowed range, emit errors and return a null type.
132 static IntegerType getChecked(unsigned width, SignednessSemantics signedness,
133 Location location);
134
135 /// Verify the construction of an integer type.
136 static LogicalResult
137 verifyConstructionInvariants(Location loc, unsigned width,
138 SignednessSemantics signedness);
139
140 /// Return the bitwidth of this integer type.
141 unsigned getWidth() const;
142
143 /// Return the signedness semantics of this integer type.
144 SignednessSemantics getSignedness() const;
145
146 /// Return true if this is a signless integer type.
isSignless()147 bool isSignless() const { return getSignedness() == Signless; }
148 /// Return true if this is a signed integer type.
isSigned()149 bool isSigned() const { return getSignedness() == Signed; }
150 /// Return true if this is an unsigned integer type.
isUnsigned()151 bool isUnsigned() const { return getSignedness() == Unsigned; }
152
153 /// Integer representation maximal bitwidth.
154 static constexpr unsigned kMaxWidth = 4096;
155 };
156
157 //===----------------------------------------------------------------------===//
158 // FloatType
159 //===----------------------------------------------------------------------===//
160
161 class FloatType : public Type {
162 public:
163 using Type::Type;
164
165 // Convenience factories.
166 static FloatType getBF16(MLIRContext *ctx);
167 static FloatType getF16(MLIRContext *ctx);
168 static FloatType getF32(MLIRContext *ctx);
169 static FloatType getF64(MLIRContext *ctx);
170
171 /// Methods for support type inquiry through isa, cast, and dyn_cast.
172 static bool classof(Type type);
173
174 /// Return the bitwidth of this float type.
175 unsigned getWidth();
176
177 /// Return the floating semantics of this float type.
178 const llvm::fltSemantics &getFloatSemantics();
179 };
180
181 //===----------------------------------------------------------------------===//
182 // BFloat16Type
183
184 class BFloat16Type
185 : public Type::TypeBase<BFloat16Type, FloatType, TypeStorage> {
186 public:
187 using Base::Base;
188
189 /// Return an instance of the bfloat16 type.
190 static BFloat16Type get(MLIRContext *context);
191 };
192
getBF16(MLIRContext * ctx)193 inline FloatType FloatType::getBF16(MLIRContext *ctx) {
194 return BFloat16Type::get(ctx);
195 }
196
197 //===----------------------------------------------------------------------===//
198 // Float16Type
199
200 class Float16Type : public Type::TypeBase<Float16Type, FloatType, TypeStorage> {
201 public:
202 using Base::Base;
203
204 /// Return an instance of the float16 type.
205 static Float16Type get(MLIRContext *context);
206 };
207
getF16(MLIRContext * ctx)208 inline FloatType FloatType::getF16(MLIRContext *ctx) {
209 return Float16Type::get(ctx);
210 }
211
212 //===----------------------------------------------------------------------===//
213 // Float32Type
214
215 class Float32Type : public Type::TypeBase<Float32Type, FloatType, TypeStorage> {
216 public:
217 using Base::Base;
218
219 /// Return an instance of the float32 type.
220 static Float32Type get(MLIRContext *context);
221 };
222
getF32(MLIRContext * ctx)223 inline FloatType FloatType::getF32(MLIRContext *ctx) {
224 return Float32Type::get(ctx);
225 }
226
227 //===----------------------------------------------------------------------===//
228 // Float64Type
229
230 class Float64Type : public Type::TypeBase<Float64Type, FloatType, TypeStorage> {
231 public:
232 using Base::Base;
233
234 /// Return an instance of the float64 type.
235 static Float64Type get(MLIRContext *context);
236 };
237
getF64(MLIRContext * ctx)238 inline FloatType FloatType::getF64(MLIRContext *ctx) {
239 return Float64Type::get(ctx);
240 }
241
242 //===----------------------------------------------------------------------===//
243 // FunctionType
244 //===----------------------------------------------------------------------===//
245
246 /// Function types map from a list of inputs to a list of results.
247 class FunctionType
248 : public Type::TypeBase<FunctionType, Type, detail::FunctionTypeStorage> {
249 public:
250 using Base::Base;
251
252 static FunctionType get(TypeRange inputs, TypeRange results,
253 MLIRContext *context);
254
255 /// Input types.
256 unsigned getNumInputs() const;
getInput(unsigned i)257 Type getInput(unsigned i) const { return getInputs()[i]; }
258 ArrayRef<Type> getInputs() const;
259
260 /// Result types.
261 unsigned getNumResults() const;
getResult(unsigned i)262 Type getResult(unsigned i) const { return getResults()[i]; }
263 ArrayRef<Type> getResults() const;
264
265 /// Returns a new function type without the specified arguments and results.
266 FunctionType getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
267 ArrayRef<unsigned> resultIndices);
268 };
269
270 //===----------------------------------------------------------------------===//
271 // NoneType
272 //===----------------------------------------------------------------------===//
273
274 /// NoneType is a unit type, i.e. a type with exactly one possible value, where
275 /// its value does not have a defined dynamic representation.
276 class NoneType : public Type::TypeBase<NoneType, Type, TypeStorage> {
277 public:
278 using Base::Base;
279
280 /// Get an instance of the NoneType.
281 static NoneType get(MLIRContext *context);
282 };
283
284 //===----------------------------------------------------------------------===//
285 // OpaqueType
286 //===----------------------------------------------------------------------===//
287
288 /// Opaque types represent types of non-registered dialects. These are types
289 /// represented in their raw string form, and can only usefully be tested for
290 /// type equality.
291 class OpaqueType
292 : public Type::TypeBase<OpaqueType, Type, detail::OpaqueTypeStorage> {
293 public:
294 using Base::Base;
295
296 /// Get or create a new OpaqueType with the provided dialect and string data.
297 static OpaqueType get(Identifier dialect, StringRef typeData,
298 MLIRContext *context);
299
300 /// Get or create a new OpaqueType with the provided dialect and string data.
301 /// If the given identifier is not a valid namespace for a dialect, then a
302 /// null type is returned.
303 static OpaqueType getChecked(Identifier dialect, StringRef typeData,
304 MLIRContext *context, Location location);
305
306 /// Returns the dialect namespace of the opaque type.
307 Identifier getDialectNamespace() const;
308
309 /// Returns the raw type data of the opaque type.
310 StringRef getTypeData() const;
311
312 /// Verify the construction of an opaque type.
313 static LogicalResult verifyConstructionInvariants(Location loc,
314 Identifier dialect,
315 StringRef typeData);
316 };
317
318 //===----------------------------------------------------------------------===//
319 // ShapedType
320 //===----------------------------------------------------------------------===//
321
322 /// This is a common base class between Vector, UnrankedTensor, RankedTensor,
323 /// and MemRef types because they share behavior and semantics around shape,
324 /// rank, and fixed element type. Any type with these semantics should inherit
325 /// from ShapedType.
326 class ShapedType : public Type {
327 public:
328 using ImplType = detail::ShapedTypeStorage;
329 using Type::Type;
330
331 // TODO: merge these two special values in a single one used everywhere.
332 // Unfortunately, uses of `-1` have crept deep into the codebase now and are
333 // hard to track.
334 static constexpr int64_t kDynamicSize = -1;
335 static constexpr int64_t kDynamicStrideOrOffset =
336 std::numeric_limits<int64_t>::min();
337
338 /// Return the element type.
339 Type getElementType() const;
340
341 /// If an element type is an integer or a float, return its width. Otherwise,
342 /// abort.
343 unsigned getElementTypeBitWidth() const;
344
345 /// If it has static shape, return the number of elements. Otherwise, abort.
346 int64_t getNumElements() const;
347
348 /// If this is a ranked type, return the rank. Otherwise, abort.
349 int64_t getRank() const;
350
351 /// Whether or not this is a ranked type. Memrefs, vectors and ranked tensors
352 /// have a rank, while unranked tensors do not.
353 bool hasRank() const;
354
355 /// If this is a ranked type, return the shape. Otherwise, abort.
356 ArrayRef<int64_t> getShape() const;
357
358 /// If this is unranked type or any dimension has unknown size (<0), it
359 /// doesn't have static shape. If all dimensions have known size (>= 0), it
360 /// has static shape.
361 bool hasStaticShape() const;
362
363 /// If this has a static shape and the shape is equal to `shape` return true.
364 bool hasStaticShape(ArrayRef<int64_t> shape) const;
365
366 /// If this is a ranked type, return the number of dimensions with dynamic
367 /// size. Otherwise, abort.
368 int64_t getNumDynamicDims() const;
369
370 /// If this is ranked type, return the size of the specified dimension.
371 /// Otherwise, abort.
372 int64_t getDimSize(unsigned idx) const;
373
374 /// Returns true if this dimension has a dynamic size (for ranked types);
375 /// aborts for unranked types.
376 bool isDynamicDim(unsigned idx) const;
377
378 /// Returns the position of the dynamic dimension relative to just the dynamic
379 /// dimensions, given its `index` within the shape.
380 unsigned getDynamicDimIndex(unsigned index) const;
381
382 /// Get the total amount of bits occupied by a value of this type. This does
383 /// not take into account any memory layout or widening constraints, e.g. a
384 /// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice
385 /// it will likely be stored as in a 4xi64 vector register. Fail an assertion
386 /// if the size cannot be computed statically, i.e. if the type has a dynamic
387 /// shape or if its elemental type does not have a known bit width.
388 int64_t getSizeInBits() const;
389
390 /// Methods for support type inquiry through isa, cast, and dyn_cast.
391 static bool classof(Type type);
392
393 /// Whether the given dimension size indicates a dynamic dimension.
isDynamic(int64_t dSize)394 static constexpr bool isDynamic(int64_t dSize) {
395 return dSize == kDynamicSize;
396 }
isDynamicStrideOrOffset(int64_t dStrideOrOffset)397 static constexpr bool isDynamicStrideOrOffset(int64_t dStrideOrOffset) {
398 return dStrideOrOffset == kDynamicStrideOrOffset;
399 }
400 };
401
402 //===----------------------------------------------------------------------===//
403 // VectorType
404 //===----------------------------------------------------------------------===//
405
406 /// Vector types represent multi-dimensional SIMD vectors, and have a fixed
407 /// known constant shape with one or more dimension.
408 class VectorType
409 : public Type::TypeBase<VectorType, ShapedType, detail::VectorTypeStorage> {
410 public:
411 using Base::Base;
412
413 /// Get or create a new VectorType of the provided shape and element type.
414 /// Assumes the arguments define a well-formed VectorType.
415 static VectorType get(ArrayRef<int64_t> shape, Type elementType);
416
417 /// Get or create a new VectorType of the provided shape and element type
418 /// declared at the given, potentially unknown, location. If the VectorType
419 /// defined by the arguments would be ill-formed, emit errors and return
420 /// nullptr-wrapping type.
421 static VectorType getChecked(ArrayRef<int64_t> shape, Type elementType,
422 Location location);
423
424 /// Verify the construction of a vector type.
425 static LogicalResult verifyConstructionInvariants(Location loc,
426 ArrayRef<int64_t> shape,
427 Type elementType);
428
429 /// Returns true of the given type can be used as an element of a vector type.
430 /// In particular, vectors can consist of integer or float primitives.
isValidElementType(Type t)431 static bool isValidElementType(Type t) {
432 return t.isa<IntegerType, FloatType>();
433 }
434
435 ArrayRef<int64_t> getShape() const;
436 };
437
438 //===----------------------------------------------------------------------===//
439 // TensorType
440 //===----------------------------------------------------------------------===//
441
442 /// Tensor types represent multi-dimensional arrays, and have two variants:
443 /// RankedTensorType and UnrankedTensorType.
444 class TensorType : public ShapedType {
445 public:
446 using ShapedType::ShapedType;
447
448 /// Return true if the specified element type is ok in a tensor.
449 static bool isValidElementType(Type type);
450
451 /// Methods for support type inquiry through isa, cast, and dyn_cast.
452 static bool classof(Type type);
453 };
454
455 //===----------------------------------------------------------------------===//
456 // RankedTensorType
457
458 /// Ranked tensor types represent multi-dimensional arrays that have a shape
459 /// with a fixed number of dimensions. Each shape element can be a non-negative
460 /// integer or unknown (represented by -1).
461 class RankedTensorType
462 : public Type::TypeBase<RankedTensorType, TensorType,
463 detail::RankedTensorTypeStorage> {
464 public:
465 using Base::Base;
466
467 /// Get or create a new RankedTensorType of the provided shape and element
468 /// type. Assumes the arguments define a well-formed type.
469 static RankedTensorType get(ArrayRef<int64_t> shape, Type elementType);
470
471 /// Get or create a new RankedTensorType of the provided shape and element
472 /// type declared at the given, potentially unknown, location. If the
473 /// RankedTensorType defined by the arguments would be ill-formed, emit errors
474 /// and return a nullptr-wrapping type.
475 static RankedTensorType getChecked(ArrayRef<int64_t> shape, Type elementType,
476 Location location);
477
478 /// Verify the construction of a ranked tensor type.
479 static LogicalResult verifyConstructionInvariants(Location loc,
480 ArrayRef<int64_t> shape,
481 Type elementType);
482
483 ArrayRef<int64_t> getShape() const;
484 };
485
486 //===----------------------------------------------------------------------===//
487 // UnrankedTensorType
488
489 /// Unranked tensor types represent multi-dimensional arrays that have an
490 /// unknown shape.
491 class UnrankedTensorType
492 : public Type::TypeBase<UnrankedTensorType, TensorType,
493 detail::UnrankedTensorTypeStorage> {
494 public:
495 using Base::Base;
496
497 /// Get or create a new UnrankedTensorType of the provided shape and element
498 /// type. Assumes the arguments define a well-formed type.
499 static UnrankedTensorType get(Type elementType);
500
501 /// Get or create a new UnrankedTensorType of the provided shape and element
502 /// type declared at the given, potentially unknown, location. If the
503 /// UnrankedTensorType defined by the arguments would be ill-formed, emit
504 /// errors and return a nullptr-wrapping type.
505 static UnrankedTensorType getChecked(Type elementType, Location location);
506
507 /// Verify the construction of a unranked tensor type.
508 static LogicalResult verifyConstructionInvariants(Location loc,
509 Type elementType);
510
getShape()511 ArrayRef<int64_t> getShape() const { return llvm::None; }
512 };
513
514 //===----------------------------------------------------------------------===//
515 // BaseMemRefType
516 //===----------------------------------------------------------------------===//
517
518 /// Base MemRef for Ranked and Unranked variants
519 class BaseMemRefType : public ShapedType {
520 public:
521 using ImplType = detail::BaseMemRefTypeStorage;
522 using ShapedType::ShapedType;
523
524 /// Return true if the specified element type is ok in a memref.
isValidElementType(Type type)525 static bool isValidElementType(Type type) {
526 return type.isIntOrIndexOrFloat() || type.isa<VectorType, ComplexType>();
527 }
528
529 /// Methods for support type inquiry through isa, cast, and dyn_cast.
530 static bool classof(Type type);
531
532 /// Returns the memory space in which data referred to by this memref resides.
533 unsigned getMemorySpace() const;
534 };
535
536 //===----------------------------------------------------------------------===//
537 // MemRefType
538
539 /// MemRef types represent a region of memory that have a shape with a fixed
540 /// number of dimensions. Each shape element can be a non-negative integer or
541 /// unknown (represented by -1). MemRef types also have an affine map
542 /// composition, represented as an array AffineMap pointers.
543 class MemRefType : public Type::TypeBase<MemRefType, BaseMemRefType,
544 detail::MemRefTypeStorage> {
545 public:
546 /// This is a builder type that keeps local references to arguments. Arguments
547 /// that are passed into the builder must out-live the builder.
548 class Builder {
549 public:
550 // Build from another MemRefType.
Builder(MemRefType other)551 explicit Builder(MemRefType other)
552 : shape(other.getShape()), elementType(other.getElementType()),
553 affineMaps(other.getAffineMaps()),
554 memorySpace(other.getMemorySpace()) {}
555
556 // Build from scratch.
Builder(ArrayRef<int64_t> shape,Type elementType)557 Builder(ArrayRef<int64_t> shape, Type elementType)
558 : shape(shape), elementType(elementType), affineMaps(), memorySpace(0) {
559 }
560
setShape(ArrayRef<int64_t> newShape)561 Builder &setShape(ArrayRef<int64_t> newShape) {
562 shape = newShape;
563 return *this;
564 }
565
setElementType(Type newElementType)566 Builder &setElementType(Type newElementType) {
567 elementType = newElementType;
568 return *this;
569 }
570
setAffineMaps(ArrayRef<AffineMap> newAffineMaps)571 Builder &setAffineMaps(ArrayRef<AffineMap> newAffineMaps) {
572 affineMaps = newAffineMaps;
573 return *this;
574 }
575
setMemorySpace(unsigned newMemorySpace)576 Builder &setMemorySpace(unsigned newMemorySpace) {
577 memorySpace = newMemorySpace;
578 return *this;
579 }
580
MemRefType()581 operator MemRefType() {
582 return MemRefType::get(shape, elementType, affineMaps, memorySpace);
583 }
584
585 private:
586 ArrayRef<int64_t> shape;
587 Type elementType;
588 ArrayRef<AffineMap> affineMaps;
589 unsigned memorySpace;
590 };
591
592 using Base::Base;
593
594 /// Get or create a new MemRefType based on shape, element type, affine
595 /// map composition, and memory space. Assumes the arguments define a
596 /// well-formed MemRef type. Use getChecked to gracefully handle MemRefType
597 /// construction failures.
598 static MemRefType get(ArrayRef<int64_t> shape, Type elementType,
599 ArrayRef<AffineMap> affineMapComposition = {},
600 unsigned memorySpace = 0);
601
602 /// Get or create a new MemRefType based on shape, element type, affine
603 /// map composition, and memory space declared at the given location.
604 /// If the location is unknown, the last argument should be an instance of
605 /// UnknownLoc. If the MemRefType defined by the arguments would be
606 /// ill-formed, emits errors (to the handler registered with the context or to
607 /// the error stream) and returns nullptr.
608 static MemRefType getChecked(ArrayRef<int64_t> shape, Type elementType,
609 ArrayRef<AffineMap> affineMapComposition,
610 unsigned memorySpace, Location location);
611
612 ArrayRef<int64_t> getShape() const;
613
614 /// Returns an array of affine map pointers representing the memref affine
615 /// map composition.
616 ArrayRef<AffineMap> getAffineMaps() const;
617
618 // TODO: merge these two special values in a single one used everywhere.
619 // Unfortunately, uses of `-1` have crept deep into the codebase now and are
620 // hard to track.
getDynamicStrideOrOffset()621 static int64_t getDynamicStrideOrOffset() {
622 return ShapedType::kDynamicStrideOrOffset;
623 }
624
625 private:
626 /// Get or create a new MemRefType defined by the arguments. If the resulting
627 /// type would be ill-formed, return nullptr. If the location is provided,
628 /// emit detailed error messages.
629 static MemRefType getImpl(ArrayRef<int64_t> shape, Type elementType,
630 ArrayRef<AffineMap> affineMapComposition,
631 unsigned memorySpace, Optional<Location> location);
632 using Base::getImpl;
633 };
634
635 //===----------------------------------------------------------------------===//
636 // UnrankedMemRefType
637
638 /// Unranked MemRef type represent multi-dimensional MemRefs that
639 /// have an unknown rank.
640 class UnrankedMemRefType
641 : public Type::TypeBase<UnrankedMemRefType, BaseMemRefType,
642 detail::UnrankedMemRefTypeStorage> {
643 public:
644 using Base::Base;
645
646 /// Get or create a new UnrankedMemRefType of the provided element
647 /// type and memory space
648 static UnrankedMemRefType get(Type elementType, unsigned memorySpace);
649
650 /// Get or create a new UnrankedMemRefType of the provided element
651 /// type and memory space declared at the given, potentially unknown,
652 /// location. If the UnrankedMemRefType defined by the arguments would be
653 /// ill-formed, emit errors and return a nullptr-wrapping type.
654 static UnrankedMemRefType getChecked(Type elementType, unsigned memorySpace,
655 Location location);
656
657 /// Verify the construction of a unranked memref type.
658 static LogicalResult verifyConstructionInvariants(Location loc,
659 Type elementType,
660 unsigned memorySpace);
661
getShape()662 ArrayRef<int64_t> getShape() const { return llvm::None; }
663 };
664
665 //===----------------------------------------------------------------------===//
666 // TupleType
667 //===----------------------------------------------------------------------===//
668
669 /// Tuple types represent a collection of other types. Note: This type merely
670 /// provides a common mechanism for representing tuples in MLIR. It is up to
671 /// dialect authors to provides operations for manipulating them, e.g.
672 /// extract_tuple_element. When possible, users should prefer multi-result
673 /// operations in the place of tuples.
674 class TupleType
675 : public Type::TypeBase<TupleType, Type, detail::TupleTypeStorage> {
676 public:
677 using Base::Base;
678
679 /// Get or create a new TupleType with the provided element types. Assumes the
680 /// arguments define a well-formed type.
681 static TupleType get(TypeRange elementTypes, MLIRContext *context);
682
683 /// Get or create an empty tuple type.
684 static TupleType get(MLIRContext *context);
685
686 /// Return the elements types for this tuple.
687 ArrayRef<Type> getTypes() const;
688
689 /// Accumulate the types contained in this tuple and tuples nested within it.
690 /// Note that this only flattens nested tuples, not any other container type,
691 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
692 /// (i32, tensor<i32>, f32, i64)
693 void getFlattenedTypes(SmallVectorImpl<Type> &types);
694
695 /// Return the number of held types.
696 size_t size() const;
697
698 /// Iterate over the held elements.
699 using iterator = ArrayRef<Type>::iterator;
begin()700 iterator begin() const { return getTypes().begin(); }
end()701 iterator end() const { return getTypes().end(); }
702
703 /// Return the element type at index 'index'.
getType(size_t index)704 Type getType(size_t index) const {
705 assert(index < size() && "invalid index for tuple type");
706 return getTypes()[index];
707 }
708 };
709
710 //===----------------------------------------------------------------------===//
711 // Deferred Method Definitions
712 //===----------------------------------------------------------------------===//
713
classof(Type type)714 inline bool BaseMemRefType::classof(Type type) {
715 return type.isa<MemRefType, UnrankedMemRefType>();
716 }
717
classof(Type type)718 inline bool FloatType::classof(Type type) {
719 return type.isa<BFloat16Type, Float16Type, Float32Type, Float64Type>();
720 }
721
classof(Type type)722 inline bool ShapedType::classof(Type type) {
723 return type.isa<RankedTensorType, VectorType, UnrankedTensorType,
724 UnrankedMemRefType, MemRefType>();
725 }
726
classof(Type type)727 inline bool TensorType::classof(Type type) {
728 return type.isa<RankedTensorType, UnrankedTensorType>();
729 }
730
731 //===----------------------------------------------------------------------===//
732 // Type Utilities
733 //===----------------------------------------------------------------------===//
734
735 /// Returns the strides of the MemRef if the layout map is in strided form.
736 /// MemRefs with layout maps in strided form include:
737 /// 1. empty or identity layout map, in which case the stride information is
738 /// the canonical form computed from sizes;
739 /// 2. single affine map layout of the form `K + k0 * d0 + ... kn * dn`,
740 /// where K and ki's are constants or symbols.
741 ///
742 /// A stride specification is a list of integer values that are either static
743 /// or dynamic (encoded with getDynamicStrideOrOffset()). Strides encode the
744 /// distance in the number of elements between successive entries along a
745 /// particular dimension. For example, `memref<42x16xf32, (64 * d0 + d1)>`
746 /// specifies a view into a non-contiguous memory region of `42` by `16` `f32`
747 /// elements in which the distance between two consecutive elements along the
748 /// outer dimension is `1` and the distance between two consecutive elements
749 /// along the inner dimension is `64`.
750 ///
751 /// Returns whether a simple strided form can be extracted from the composition
752 /// of the layout map.
753 ///
754 /// The convention is that the strides for dimensions d0, .. dn appear in
755 /// order to make indexing intuitive into the result.
756 LogicalResult getStridesAndOffset(MemRefType t,
757 SmallVectorImpl<int64_t> &strides,
758 int64_t &offset);
759 LogicalResult getStridesAndOffset(MemRefType t,
760 SmallVectorImpl<AffineExpr> &strides,
761 AffineExpr &offset);
762
763 /// Given a list of strides (in which MemRefType::getDynamicStrideOrOffset()
764 /// represents a dynamic value), return the single result AffineMap which
765 /// represents the linearized strided layout map. Dimensions correspond to the
766 /// offset followed by the strides in order. Symbols are inserted for each
767 /// dynamic dimension in order. A stride cannot take value `0`.
768 ///
769 /// Examples:
770 /// =========
771 ///
772 /// 1. For offset: 0 strides: ?, ?, 1 return
773 /// (i, j, k)[M, N]->(M * i + N * j + k)
774 ///
775 /// 2. For offset: 3 strides: 32, ?, 16 return
776 /// (i, j, k)[M]->(3 + 32 * i + M * j + 16 * k)
777 ///
778 /// 3. For offset: ? strides: ?, ?, ? return
779 /// (i, j, k)[off, M, N, P]->(off + M * i + N * j + P * k)
780 AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset,
781 MLIRContext *context);
782
783 /// Return a version of `t` with identity layout if it can be determined
784 /// statically that the layout is the canonical contiguous strided layout.
785 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
786 /// `t` with simplified layout.
787 MemRefType canonicalizeStridedLayout(MemRefType t);
788
789 /// Return a version of `t` with a layout that has all dynamic offset and
790 /// strides. This is used to erase the static layout.
791 MemRefType eraseStridedLayout(MemRefType t);
792
793 /// Given MemRef `sizes` that are either static or dynamic, returns the
794 /// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
795 /// once a dynamic dimension is encountered, all canonical strides become
796 /// dynamic and need to be encoded with a different symbol.
797 /// For canonical strides expressions, the offset is always 0 and and fastest
798 /// varying stride is always `1`.
799 ///
800 /// Examples:
801 /// - memref<3x4x5xf32> has canonical stride expression
802 /// `20*exprs[0] + 5*exprs[1] + exprs[2]`.
803 /// - memref<3x?x5xf32> has canonical stride expression
804 /// `s0*exprs[0] + 5*exprs[1] + exprs[2]`.
805 /// - memref<3x4x?xf32> has canonical stride expression
806 /// `s1*exprs[0] + s0*exprs[1] + exprs[2]`.
807 AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
808 ArrayRef<AffineExpr> exprs,
809 MLIRContext *context);
810
811 /// Return the result of makeCanonicalStrudedLayoutExpr for the common case
812 /// where `exprs` is {d0, d1, .., d_(sizes.size()-1)}
813 AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
814 MLIRContext *context);
815
816 /// Return true if the layout for `t` is compatible with strided semantics.
817 bool isStrided(MemRefType t);
818
819 } // end namespace mlir
820
821 #endif // MLIR_IR_BUILTINTYPES_H
822