1 //===- Interchange.cpp - Linalg interchange transformation ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg interchange transformation.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16 #include "mlir/Dialect/Linalg/Utils/Utils.h"
17 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
18 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
19 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
20 #include "mlir/Dialect/Vector/VectorOps.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Support/LLVM.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/raw_ostream.h"
28 #include <type_traits>
29
30 #define DEBUG_TYPE "linalg-interchange"
31
32 using namespace mlir;
33 using namespace mlir::linalg;
34
interchangeGenericLinalgOpPrecondition(Operation * op,ArrayRef<unsigned> interchangeVector)35 LogicalResult mlir::linalg::interchangeGenericLinalgOpPrecondition(
36 Operation *op, ArrayRef<unsigned> interchangeVector) {
37 if (interchangeVector.empty())
38 return failure();
39 // Transformation applies to generic ops only.
40 if (!isa<GenericOp, IndexedGenericOp>(op))
41 return failure();
42 LinalgOp linOp = cast<LinalgOp>(op);
43 // Transformation applies to buffers only.
44 if (!linOp.hasBufferSemantics())
45 return failure();
46 // Permutation must be applicable.
47 if (linOp.getIndexingMap(0).getNumInputs() != interchangeVector.size())
48 return failure();
49 // Permutation map must be invertible.
50 if (!inversePermutation(
51 AffineMap::getPermutationMap(interchangeVector, op->getContext())))
52 return failure();
53 return success();
54 }
55
interchange(LinalgOp op,ArrayRef<unsigned> interchangeVector)56 LinalgOp mlir::linalg::interchange(LinalgOp op,
57 ArrayRef<unsigned> interchangeVector) {
58 if (interchangeVector.empty())
59 return op;
60
61 MLIRContext *context = op.getContext();
62 auto permutationMap = inversePermutation(
63 AffineMap::getPermutationMap(interchangeVector, context));
64 assert(permutationMap && "expected permutation to be invertible");
65 SmallVector<Attribute, 4> newIndexingMaps;
66 auto indexingMaps = op.indexing_maps().getValue();
67 for (unsigned i = 0, e = op.getNumInputsAndOutputs(); i != e; ++i) {
68 AffineMap m = indexingMaps[i].cast<AffineMapAttr>().getValue();
69 if (!permutationMap.isEmpty())
70 m = m.compose(permutationMap);
71 newIndexingMaps.push_back(AffineMapAttr::get(m));
72 }
73 auto itTypes = op.iterator_types().getValue();
74 SmallVector<Attribute, 4> itTypesVector;
75 for (unsigned i = 0, e = itTypes.size(); i != e; ++i)
76 itTypesVector.push_back(itTypes[i]);
77 applyPermutationToVector(itTypesVector, interchangeVector);
78
79 op.setAttr(getIndexingMapsAttrName(),
80 ArrayAttr::get(newIndexingMaps, context));
81 op.setAttr(getIteratorTypesAttrName(),
82 ArrayAttr::get(itTypesVector, context));
83
84 return op;
85 }
86