//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//

ComplexType ComplexType::get(Type elementType) {
  return Base::get(elementType.getContext(), elementType);
}

ComplexType ComplexType::getChecked(Type elementType, Location location) {
  return Base::getChecked(location, elementType);
}

/// Verify the construction of a complex type.
LogicalResult ComplexType::verifyConstructionInvariants(Location loc,
                                                        Type elementType) {
  if (!elementType.isIntOrFloat())
    return emitError(loc, "invalid element type for complex");
  return success();
}

Type ComplexType::getElementType() { return getImpl()->elementType; }

//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//

// A static constexpr member must have an out-of-class definition (until C++17,
// which introduced inline variables).
constexpr unsigned IntegerType::kMaxWidth;

/// Verify the construction of an integer type.
LogicalResult
IntegerType::verifyConstructionInvariants(Location loc, unsigned width,
                                          SignednessSemantics signedness) {
  if (width > IntegerType::kMaxWidth) {
    return emitError(loc) << "integer bitwidth is limited to "
                          << IntegerType::kMaxWidth << " bits";
  }
  return success();
}

unsigned IntegerType::getWidth() const { return getImpl()->width; }

IntegerType::SignednessSemantics IntegerType::getSignedness() const {
  return getImpl()->signedness;
}

//===----------------------------------------------------------------------===//
// Float Type
//===----------------------------------------------------------------------===//

unsigned FloatType::getWidth() {
  if (isa<Float16Type, BFloat16Type>())
    return 16;
  if (isa<Float32Type>())
    return 32;
  if (isa<Float64Type>())
    return 64;
  llvm_unreachable("unexpected float type");
}

/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
  if (isa<BFloat16Type>())
    return APFloat::BFloat();
  if (isa<Float16Type>())
    return APFloat::IEEEhalf();
  if (isa<Float32Type>())
    return APFloat::IEEEsingle();
  if (isa<Float64Type>())
    return APFloat::IEEEdouble();
  llvm_unreachable("non-floating point type used");
}

//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//

FunctionType FunctionType::get(TypeRange inputs, TypeRange results,
                               MLIRContext *context) {
  return Base::get(context, inputs, results);
}

unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }

ArrayRef<Type> FunctionType::getInputs() const {
  return getImpl()->getInputs();
}

unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }

ArrayRef<Type> FunctionType::getResults() const {
  return getImpl()->getResults();
}

/// Helper to call a callback once on each index in the range
/// [0, `totalIndices`), *except* for the indices given in `indices`.
/// `indices` is allowed to have duplicates and can be in any order.
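/// For example, iterateIndicesExcept(5, {1, 3}, fn) invokes fn(0), fn(2), and
/// fn(4).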
inline void iterateIndicesExcept(unsigned totalIndices,
                                 ArrayRef<unsigned> indices,
                                 function_ref<void(unsigned)> callback) {
  llvm::BitVector skipIndices(totalIndices);
  for (unsigned i : indices)
    skipIndices.set(i);

  for (unsigned i = 0; i < totalIndices; ++i)
    if (!skipIndices.test(i))
      callback(i);
}

/// Returns a new function type without the specified arguments and results.
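/// For example, dropping argument 0 and result 1 from (i32, f32) -> (i1, i64)
/// produces (f32) -> (i1).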
FunctionType
FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
                                       ArrayRef<unsigned> resultIndices) {
  ArrayRef<Type> newInputTypes = getInputs();
  SmallVector<Type, 4> newInputTypesBuffer;
  if (!argIndices.empty()) {
    unsigned originalNumArgs = getNumInputs();
    iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
      newInputTypesBuffer.emplace_back(getInput(i));
    });
    newInputTypes = newInputTypesBuffer;
  }

  ArrayRef<Type> newResultTypes = getResults();
  SmallVector<Type, 4> newResultTypesBuffer;
  if (!resultIndices.empty()) {
    unsigned originalNumResults = getNumResults();
    iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
      newResultTypesBuffer.emplace_back(getResult(i));
    });
    newResultTypes = newResultTypesBuffer;
  }

  return get(newInputTypes, newResultTypes, getContext());
}

//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//

OpaqueType OpaqueType::get(Identifier dialect, StringRef typeData,
                           MLIRContext *context) {
  return Base::get(context, dialect, typeData);
}

OpaqueType OpaqueType::getChecked(Identifier dialect, StringRef typeData,
                                  MLIRContext *context, Location location) {
  return Base::getChecked(location, dialect, typeData);
}

/// Returns the dialect namespace of the opaque type.
Identifier OpaqueType::getDialectNamespace() const {
  return getImpl()->dialectNamespace;
}

/// Returns the raw type data of the opaque type.
StringRef OpaqueType::getTypeData() const { return getImpl()->typeData; }

/// Verify the construction of an opaque type.
LogicalResult OpaqueType::verifyConstructionInvariants(Location loc,
                                                       Identifier dialect,
                                                       StringRef typeData) {
  if (!Dialect::isValidNamespace(dialect.strref()))
    return emitError(loc, "invalid dialect namespace '") << dialect << "'";
  return success();
}

//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
constexpr int64_t ShapedType::kDynamicSize;
constexpr int64_t ShapedType::kDynamicStrideOrOffset;

Type ShapedType::getElementType() const {
  return static_cast<ImplType *>(impl)->elementType;
}

unsigned ShapedType::getElementTypeBitWidth() const {
  return getElementType().getIntOrFloatBitWidth();
}

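/// Returns the total number of elements, i.e. the product of all dimension
/// sizes; e.g. a 2x3x4 shape holds 24 elements.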
int64_t ShapedType::getNumElements() const {
  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
  auto shape = getShape();
  int64_t num = 1;
  for (auto dim : shape)
    num *= dim;
  return num;
}

int64_t ShapedType::getRank() const { return getShape().size(); }

bool ShapedType::hasRank() const {
  return !isa<UnrankedMemRefType, UnrankedTensorType>();
}

int64_t ShapedType::getDimSize(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return getShape()[idx];
}

bool ShapedType::isDynamicDim(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return isDynamic(getShape()[idx]);
}

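/// Returns the position of the dynamic dimension `index` among the dynamic
/// dimensions; e.g. for shape (2, ?, 4, ?), getDynamicDimIndex(3) returns 1.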
unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
  assert(index < getRank() && "invalid index");
  assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index");
  return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
}

/// Get the number of bits required to store a value of the given shaped type.
/// Compute the value recursively since tensors are allowed to have vectors as
/// elements.
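/// For example, tensor<4xvector<8xf32>> takes 4 * 8 * 32 = 1024 bits.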
int64_t ShapedType::getSizeInBits() const {
  assert(hasStaticShape() &&
         "cannot get the bit size of an aggregate with a dynamic shape");

  auto elementType = getElementType();
  if (elementType.isIntOrFloat())
    return elementType.getIntOrFloatBitWidth() * getNumElements();

  if (auto complexType = elementType.dyn_cast<ComplexType>()) {
    elementType = complexType.getElementType();
    return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
  }

  // Tensors can have vectors and other tensors as elements, other shaped types
  // cannot.
  assert(isa<TensorType>() && "unsupported element type");
  assert((elementType.isa<VectorType, TensorType>()) &&
         "unsupported tensor element type");
  return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
}

ArrayRef<int64_t> ShapedType::getShape() const {
  if (auto vectorType = dyn_cast<VectorType>())
    return vectorType.getShape();
  if (auto tensorType = dyn_cast<RankedTensorType>())
    return tensorType.getShape();
  return cast<MemRefType>().getShape();
}

int64_t ShapedType::getNumDynamicDims() const {
  return llvm::count_if(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape() const {
  return hasRank() && llvm::none_of(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
  return hasStaticShape() && getShape() == shape;
}

//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//

VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
  return Base::get(elementType.getContext(), shape, elementType);
}

VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType,
                                  Location location) {
  return Base::getChecked(location, shape, elementType);
}

LogicalResult VectorType::verifyConstructionInvariants(Location loc,
                                                       ArrayRef<int64_t> shape,
                                                       Type elementType) {
  if (shape.empty())
    return emitError(loc, "vector types must have at least one dimension");

  if (!isValidElementType(elementType))
    return emitError(loc, "vector elements must be int or float type");

  if (any_of(shape, [](int64_t i) { return i <= 0; }))
    return emitError(loc, "vector types must have positive constant sizes");

  return success();
}

ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }

//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//

// Check if `elementType` can be an element type of a tensor. Emit errors at
// `location` and return failure if the check fails.
static LogicalResult checkTensorElementType(Location location,
                                            Type elementType) {
  if (!TensorType::isValidElementType(elementType))
    return emitError(location, "invalid tensor element type: ") << elementType;
  return success();
}

/// Return true if the specified element type is valid in a tensor.
bool TensorType::isValidElementType(Type type) {
  // Note: Non-standard/builtin types are allowed to exist within tensor
  // types. Dialects are expected to verify that tensor types have a valid
  // element type within that dialect.
  return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
                  IndexType>() ||
         !type.getDialect().getNamespace().empty();
}

//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//

RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
                                       Type elementType) {
  return Base::get(elementType.getContext(), shape, elementType);
}

RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape,
                                              Type elementType,
                                              Location location) {
  return Base::getChecked(location, shape, elementType);
}

LogicalResult RankedTensorType::verifyConstructionInvariants(
    Location loc, ArrayRef<int64_t> shape, Type elementType) {
  for (int64_t s : shape) {
    if (s < -1)
      return emitError(loc, "invalid tensor dimension size");
  }
  return checkTensorElementType(loc, elementType);
}

ArrayRef<int64_t> RankedTensorType::getShape() const {
  return getImpl()->getShape();
}

//===----------------------------------------------------------------------===//
// UnrankedTensorType
//===----------------------------------------------------------------------===//

UnrankedTensorType UnrankedTensorType::get(Type elementType) {
  return Base::get(elementType.getContext(), elementType);
}

UnrankedTensorType UnrankedTensorType::getChecked(Type elementType,
                                                  Location location) {
  return Base::getChecked(location, elementType);
}

LogicalResult
UnrankedTensorType::verifyConstructionInvariants(Location loc,
                                                 Type elementType) {
  return checkTensorElementType(loc, elementType);
}

//===----------------------------------------------------------------------===//
// BaseMemRefType
//===----------------------------------------------------------------------===//

unsigned BaseMemRefType::getMemorySpace() const {
  return static_cast<ImplType *>(impl)->memorySpace;
}

//===----------------------------------------------------------------------===//
// MemRefType
//===----------------------------------------------------------------------===//

/// Get or create a new MemRefType based on shape, element type, affine
/// map composition, and memory space. Assumes the arguments define a
/// well-formed MemRef type. Use getChecked to gracefully handle MemRefType
/// construction failures.
MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           ArrayRef<AffineMap> affineMapComposition,
                           unsigned memorySpace) {
  auto result = getImpl(shape, elementType, affineMapComposition, memorySpace,
                        /*location=*/llvm::None);
  assert(result && "Failed to construct instance of MemRefType.");
  return result;
}

/// Get or create a new MemRefType based on shape, element type, affine
/// map composition, and memory space declared at the given location.
/// If the location is unknown, the last argument should be an instance of
/// UnknownLoc. If the MemRefType defined by the arguments would be
/// ill-formed, emits errors (to the handler registered with the context or to
/// the error stream) and returns nullptr.
MemRefType MemRefType::getChecked(ArrayRef<int64_t> shape, Type elementType,
                                  ArrayRef<AffineMap> affineMapComposition,
                                  unsigned memorySpace, Location location) {
  return getImpl(shape, elementType, affineMapComposition, memorySpace,
                 location);
}

/// Get or create a new MemRefType defined by the arguments. If the resulting
/// type would be ill-formed, return nullptr. If the location is provided,
/// emit detailed error messages. To emit errors when the location is unknown,
/// pass in an instance of UnknownLoc.
MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
                               ArrayRef<AffineMap> affineMapComposition,
                               unsigned memorySpace,
                               Optional<Location> location) {
  auto *context = elementType.getContext();

  if (!BaseMemRefType::isValidElementType(elementType))
    return emitOptionalError(location, "invalid memref element type"),
           MemRefType();

  for (int64_t s : shape) {
    // Negative sizes are not allowed except for `-1` that means dynamic size.
    if (s < -1)
      return emitOptionalError(location, "invalid memref size"), MemRefType();
  }

  // Check that the structure of the composition is valid, i.e. that each
  // subsequent affine map has as many inputs as the previous map has results.
  // Take the dimensionality of the MemRef for the first map.
  auto dim = shape.size();
  unsigned i = 0;
  for (const auto &affineMap : affineMapComposition) {
    if (affineMap.getNumDims() != dim) {
      if (location)
        emitError(*location)
            << "memref affine map dimension mismatch between "
            << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
            << " and affine map " << i + 1 << ": " << dim
            << " != " << affineMap.getNumDims();
      return nullptr;
    }

    dim = affineMap.getNumResults();
    ++i;
  }

  // Drop identity maps from the composition.
  // This may lead to the composition becoming empty, which is interpreted as an
  // implicit identity.
  SmallVector<AffineMap, 2> cleanedAffineMapComposition;
  for (const auto &map : affineMapComposition) {
    if (map.isIdentity())
      continue;
    cleanedAffineMapComposition.push_back(map);
  }

  return Base::get(context, shape, elementType, cleanedAffineMapComposition,
                   memorySpace);
}

ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); }

ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
  return getImpl()->getAffineMaps();
}

//===----------------------------------------------------------------------===//
// UnrankedMemRefType
//===----------------------------------------------------------------------===//

UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
                                           unsigned memorySpace) {
  return Base::get(elementType.getContext(), elementType, memorySpace);
}

UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType,
                                                  unsigned memorySpace,
                                                  Location location) {
  return Base::getChecked(location, elementType, memorySpace);
}

LogicalResult
UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
                                                 unsigned memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError(loc, "invalid memref element type");
  return success();
}

// Fallback case for a terminal dim/sym/cst that is not part of a binary op
// (i.e. a single term). Accumulate the AffineExpr into the existing one.
static void extractStridesFromTerm(AffineExpr e,
                                   AffineExpr multiplicativeFactor,
                                   MutableArrayRef<AffineExpr> strides,
                                   AffineExpr &offset) {
  if (auto dim = e.dyn_cast<AffineDimExpr>())
    strides[dim.getPosition()] =
        strides[dim.getPosition()] + multiplicativeFactor;
  else
    offset = offset + e * multiplicativeFactor;
}

/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// strides expressions for each dim position.
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing intuitive into the result.
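/// For example, with a unit `multiplicativeFactor`, the expression
/// d0 * 42 + d1 * 5 + d2 + 2 yields strides [42, 5, 1] and offset 2.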
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
  if (!bin) {
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  if (bin.getKind() == AffineExprKind::CeilDiv ||
      bin.getKind() == AffineExprKind::FloorDiv ||
      bin.getKind() == AffineExprKind::Mod)
    return failure();

  if (bin.getKind() == AffineExprKind::Mul) {
    auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
    if (dim) {
      strides[dim.getPosition()] =
          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
      return success();
    }
    // LHS and RHS may both contain complex expressions of dims. Try one path
    // and if it fails try the other. This is guaranteed to succeed because
    // only one path may have a `dim`, otherwise this is not an AffineExpr in
    // the first place.
    if (bin.getLHS().isSymbolicOrConstant())
      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
                            strides, offset);
    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
                          strides, offset);
  }

  if (bin.getKind() == AffineExprKind::Add) {
    auto res1 =
        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
    auto res2 =
        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
    return success(succeeded(res1) && succeeded(res2));
  }

  llvm_unreachable("unexpected binary operation");
}

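/// For example, for memref<3x4xf32> with the default (empty) layout this
/// computes strides [4, 1] and offset 0.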
LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<AffineExpr> &strides,
                                        AffineExpr &offset) {
  auto affineMaps = t.getAffineMaps();
  // For now strides are only computed on a single affine map with a single
  // result (i.e. the closed subset of linearization maps that are compatible
  // with striding semantics).
  // TODO: support more forms on a per-need basis.
  if (affineMaps.size() > 1)
    return failure();
  if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1)
    return failure();

  auto zero = getAffineConstantExpr(0, t.getContext());
  auto one = getAffineConstantExpr(1, t.getContext());
  offset = zero;
  strides.assign(t.getRank(), zero);

  AffineMap m;
  if (!affineMaps.empty()) {
    m = affineMaps.front();
    assert(!m.isIdentity() && "unexpected identity map");
  }

  // Canonical case for empty map.
  if (!m) {
    // 0-D corner case, offset is already 0.
    if (t.getRank() == 0)
      return success();
    auto stridedExpr =
        makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
      return success();
    assert(false && "unexpected failure: extract strides in canonical layout");
  }

  // Non-canonical case requires more work.
  auto stridedExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  // Simplify results to allow folding to constants and simple checks.
  unsigned numDims = m.getNumDims();
  unsigned numSymbols = m.getNumSymbols();
  offset = simplifyAffineExpr(offset, numDims, numSymbols);
  for (auto &stride : strides)
    stride = simplifyAffineExpr(stride, numDims, numSymbols);

  // In practice, a strided memref must be internally non-aliasing. Test
  // against 0 as a proxy.
  // TODO: static cases can have more advanced checks.
  // TODO: dynamic cases would require a way to compare symbolic
  // expressions and would probably need an affine set context propagated
  // everywhere.
  if (llvm::any_of(strides, [](AffineExpr e) {
        return e == getAffineConstantExpr(0, e.getContext());
      })) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  return success();
}

LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<int64_t> &strides,
                                        int64_t &offset) {
  AffineExpr offsetExpr;
  SmallVector<AffineExpr, 4> strideExprs;
  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
    return failure();
  if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
    offset = cst.getValue();
  else
    offset = ShapedType::kDynamicStrideOrOffset;
  for (auto e : strideExprs) {
    if (auto c = e.dyn_cast<AffineConstantExpr>())
      strides.push_back(c.getValue());
    else
      strides.push_back(ShapedType::kDynamicStrideOrOffset);
  }
  return success();
}

//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//

/// Get or create a new TupleType with the provided element types. Assumes the
/// arguments define a well-formed type.
TupleType TupleType::get(TypeRange elementTypes, MLIRContext *context) {
  return Base::get(context, elementTypes);
}

/// Get or create an empty tuple type.
TupleType TupleType::get(MLIRContext *context) { return get({}, context); }

/// Return the element types of this tuple.
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }

/// Accumulate the types contained in this tuple and tuples nested within it.
/// Note that this only flattens nested tuples, not any other container type,
/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
/// (i32, tensor<i32>, f32, i64)
void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
  for (Type type : getTypes()) {
    if (auto nestedTuple = type.dyn_cast<TupleType>())
      nestedTuple.getFlattenedTypes(types);
    else
      types.push_back(type);
  }
}

/// Return the number of element types.
size_t TupleType::size() const { return getImpl()->size(); }

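/// Builds the affine map for a strided layout, allocating symbol s0 for a
/// dynamic offset and one new symbol per dynamic stride; e.g. strides [?, 1]
/// with a dynamic offset produce a map equivalent to
/// (d0, d1)[s0, s1] -> (s0 + d0 * s1 + d1).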
AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
                                           int64_t offset,
                                           MLIRContext *context) {
  AffineExpr expr;
  unsigned nSymbols = 0;

  // AffineExpr for offset.
  // Static case.
  if (offset != MemRefType::getDynamicStrideOrOffset()) {
    auto cst = getAffineConstantExpr(offset, context);
    expr = cst;
  } else {
    // Dynamic case, new symbol for the offset.
    auto sym = getAffineSymbolExpr(nSymbols++, context);
    expr = sym;
  }

  // AffineExpr for strides.
  for (auto en : llvm::enumerate(strides)) {
    auto dim = en.index();
    auto stride = en.value();
    assert(stride != 0 && "Invalid stride specification");
    auto d = getAffineDimExpr(dim, context);
    AffineExpr mult;
    // Static case.
    if (stride != MemRefType::getDynamicStrideOrOffset())
      mult = getAffineConstantExpr(stride, context);
    else
      // Dynamic case, new symbol for each new stride.
      mult = getAffineSymbolExpr(nSymbols++, context);
    expr = expr + d * mult;
  }

  return AffineMap::get(strides.size(), nSymbols, expr);
}

/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with simplified layout.
/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
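/// For example, memref<4x5xf32, (d0, d1) -> (d0 * 5 + d1)> canonicalizes to
/// the identity-layout memref<4x5xf32>.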
MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
  auto affineMaps = t.getAffineMaps();
  // Already in canonical form.
  if (affineMaps.empty())
    return t;

  // Can't reduce to canonical identity form, return in canonical form.
  if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1)
    return t;

  // If the canonical strided layout for the sizes of `t` is equal to the
  // simplified layout of `t` we can just return an empty layout. Otherwise,
  // just simplify the existing layout.
  AffineExpr expr =
      makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
  auto m = affineMaps[0];
  auto simplifiedLayoutExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (expr != simplifiedLayoutExpr)
    return MemRefType::Builder(t).setAffineMaps({AffineMap::get(
        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)});
  return MemRefType::Builder(t).setAffineMaps({});
}

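/// Computes the canonical row-major strided layout expression for the given
/// sizes; e.g. sizes (3, 4, 5) with dim exprs (d0, d1, d2) yield an expression
/// equivalent to d0 * 20 + d1 * 5 + d2. Once a dynamic size is seen, the
/// strides of all dimensions to its left become fresh symbols.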
AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                ArrayRef<AffineExpr> exprs,
                                                MLIRContext *context) {
  // Size 0 corner case is useful for canonicalizations.
  if (llvm::is_contained(sizes, 0))
    return getAffineConstantExpr(0, context);

  auto maps = AffineMap::inferFromExprList(exprs);
  assert(!maps.empty() && "Expected one non-empty map");
  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();

  AffineExpr expr;
  bool dynamicPoisonBit = false;
  int64_t runningSize = 1;
  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
    int64_t size = std::get<1>(en);
    // Degenerate case: no size means no stride.
    if (size == 0)
      continue;
    AffineExpr dimExpr = std::get<0>(en);
    AffineExpr stride = dynamicPoisonBit
                            ? getAffineSymbolExpr(nSymbols++, context)
                            : getAffineConstantExpr(runningSize, context);
    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
    if (size > 0)
      runningSize *= size;
    else
      dynamicPoisonBit = true;
  }
  return simplifyAffineExpr(expr, numDims, nSymbols);
}

/// Return a version of `t` with a layout that has all dynamic offset and
/// strides. This is used to erase the static layout.
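/// For example, a 2-D memref gets a layout equivalent to
/// (d0, d1)[s0, s1, s2] -> (s0 + d0 * s1 + d1 * s2), where s0 is the offset
/// and s1, s2 are the strides.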
MemRefType mlir::eraseStridedLayout(MemRefType t) {
  auto val = ShapedType::kDynamicStrideOrOffset;
  return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap(
      SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()));
}

AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                MLIRContext *context) {
  SmallVector<AffineExpr, 4> exprs;
  exprs.reserve(sizes.size());
  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
    exprs.push_back(getAffineDimExpr(dim, context));
  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
}

/// Return true if the layout for `t` is compatible with strided semantics.
bool mlir::isStrided(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> stridesAndOffset;
  auto res = getStridesAndOffset(t, stridesAndOffset, offset);
  return succeeded(res);
}