1 //===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic and helpers to expose Linalg transforms as rewrite
10 // patterns.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/Utils/Utils.h"
18 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
19 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
20 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
21 #include "mlir/Dialect/Vector/VectorOps.h"
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include <type_traits>
30
31 #define DEBUG_TYPE "linalg-transforms"
32
33 using namespace mlir;
34 using namespace mlir::edsc;
35 using namespace mlir::edsc::intrinsics;
36 using namespace mlir::linalg;
37
38 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
39
40 //===----------------------------------------------------------------------===//
41 // Transformations exposed as rewrite patterns.
42 //===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
// Patterns read this attribute to decide whether to fire and write it back (or
// remove it) after a successful rewrite; see LinalgMarker below.
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";
46
/// Build a marker from the disjunction of marker names to match and an
/// optional replacement marker to install after a successful rewrite (no
/// replacement means the marker attribute is removed).
mlir::linalg::LinalgMarker::LinalgMarker(ArrayRef<Identifier> matchDisjunction,
                                         Optional<Identifier> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {}
51
52 LogicalResult
checkAndNotify(PatternRewriter & rewriter,Operation * op) const53 mlir::linalg::LinalgMarker::checkAndNotify(PatternRewriter &rewriter,
54 Operation *op) const {
55 auto attr = op->template getAttrOfType<StringAttr>(
56 LinalgTransforms::kLinalgTransformMarker);
57
58 if (!attr) {
59 // 1. Has no marker case and matchDisjunction is empty.
60 if (matchDisjunction.empty())
61 return success();
62
63 // 2. Has no marker but was expecting a marker.
64 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
65 diag << " does not have any marker from list: ";
66 interleaveComma(matchDisjunction, diag);
67 });
68 }
69
70 // 4. Match explicit marker.
71 for (auto marker : matchDisjunction)
72 if (attr.getValue() == marker)
73 return success();
74
75 // 5. Fail to match.
76 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
77 diag << " does not have any marker from list: ";
78 interleaveComma(matchDisjunction, diag);
79 });
80 }
81
replaceLinalgMarker(PatternRewriter & rewriter,Operation * op) const82 void mlir::linalg::LinalgMarker::replaceLinalgMarker(PatternRewriter &rewriter,
83 Operation *op) const {
84 if (replacement.hasValue())
85 op->setAttr(LinalgTransforms::kLinalgTransformMarker,
86 rewriter.getStringAttr(replacement.getValue()));
87 else
88 op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
89 rewriter.getContext()));
90 }
91
92 LinalgTilingOptions &
setTileSizes(ArrayRef<int64_t> ts)93 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
94 SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
95 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
96 OpBuilder::InsertionGuard guard(b);
97 b.setInsertionPointToStart(
98 &op->getParentOfType<FuncOp>().getBody().front());
99 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
100 Value v = b.create<ConstantIndexOp>(op->getLoc(), s);
101 return v;
102 }));
103 };
104 return *this;
105 }
106
/// Linalg base tiling pattern.
/// Variant anchored on a specific op name: only ops named `opName` are
/// considered for tiling.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgMarker marker, PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context), marker(marker),
      options(options) {}
113
/// Variant that matches any operation; the LinalgOp cast and the marker check
/// in matchAndRewriteBase perform the actual filtering.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    LinalgTilingOptions options, LinalgMarker marker, PatternBenefit benefit)
    : RewritePattern(benefit, MatchAnyOpTypeTag()), marker(marker),
      options(options) {}
118
/// Try to tile `op`. On success, populates `tensorResults` with the tensor
/// results of the tiled op so derived patterns can finish the rewrite;
/// replacing the original op is left to the caller.
LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter,
    SmallVectorImpl<Value> &tensorResults) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  // Only fire on ops carrying one of the expected transform markers (if any).
  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
    return failure();

  // If LinalgOp has results, they must all be tied to init tensors.
  // We enforce this to ensure all tiled ops have been rewritten in
  // "init tensor" form. This ensures tiling has anchor values into which to
  // subtensor / subtensor_insert. Otherwise tiling would need to allocate which
  // is not acceptable.
  // This would not be the case with a special terminator op that generates the
  // whole tensor (instead of inserting a subtensor). But the generator-based
  // abstraction has other issues.
  if (linalgOp.getNumInitTensors() != linalgOp->getNumResults())
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);

  // Tiling itself may fail (e.g. if no tile sizes apply).
  if (!res)
    return failure();

  // Return relevant information to derived pattern.
  tensorResults = res->tensorResults;

  // New marker if specified.
  marker.replaceLinalgMarker(rewriter, res->op.getOperation());
  return success();
}
151
/// Construct a tile-and-fuse pattern anchored on `opName`. The three markers
/// respectively control: matching/relabeling of the tiled root op (`marker`),
/// labeling of the newly created fused producers (`fusedOpMarker`), and
/// labeling of the original ops left in place (`originalOpMarker`).
mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgMarker marker, LinalgMarker fusedOpMarker,
    LinalgMarker originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), marker(marker),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}
162
/// Tile `op` and fuse the producers selected by `fusionOptions.indicesToFuse`
/// into the tiled loops, then tile the remaining (unfused) loops. On success,
/// updates the various markers so the participating ops are not rewritten
/// again by this pattern.
LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
    return failure();
  // Only ops with buffer semantics are supported here.
  if (!linalgOp.hasBufferSemantics())
    return failure();

  // Collect the Linalg producers feeding the operands selected for fusion.
  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperations(linalgOp)) {
    if (!fusionOptions.indicesToFuse.count(
            dependence.indexingOpView.operandIndex))
      continue;
    if (isa<LinalgOp>(dependence.dependentOpView.op))
      producers.insert(dependence.dependentOpView.op);
  }

  // Gather, in block order, the producer ops preceding `op`, with `op` itself
  // appended last. NOTE(review): tileAndFuseLinalgOps appears to rely on this
  // ordering (root last) -- confirm against its contract.
  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  // Materialize tile sizes once; they are reused both for the fused tiling
  // and for deriving the unfused-loop tile sizes below.
  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops: use a zero tile size for loops that were already
  // tiled and fused, and the original tile size for the rest.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Tile the loop only if there is a non-zero tile size.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<ConstantIndexOp>())
          return cst.getValue() != 0;
        // Non-constant tile size: conservatively assume non-zero.
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    // Replace the intermediate tiled op with the fully tiled one.
    rewriter.eraseOp(tiledAndFusedOps->op);
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }

  // Relabel: tiled root gets `marker`, fused producers get `fusedOpMarker`,
  // and the original producers (all of `fusionOps` except the root, which is
  // last) plus the matched op get `originalOpMarker`.
  marker.replaceLinalgMarker(rewriter, tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgMarker(rewriter, fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgMarker(rewriter,
                                         origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(
      op, [&]() { originalOpMarker.replaceLinalgMarker(rewriter, op); });
  return success();
}
240
/// Linalg base interchange pattern.
/// `interchangeVector` is the loop permutation to apply to matched ops.
mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
    StringRef opName, MLIRContext *context,
    ArrayRef<unsigned> interchangeVector, LinalgMarker marker,
    PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context), marker(marker),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
248
/// Apply the configured loop interchange to `op` in place when the marker and
/// the interchange precondition allow it.
LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
    return failure();
  // Bail out early if the permutation is not applicable to this op.
  if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector)))
    return failure();

  // TODO: figure out how this interplays with named ops. In particular this
  // should break the named op property.
  rewriter.updateRootInPlace(op, [&]() {
    // Permute the op's loops/indexing maps in place.
    interchange(linalgOp, interchangeVector);
    // New marker if specified.
    marker.replaceLinalgMarker(rewriter, op);
  });
  return success();
}
268
/// Construct a promotion pattern anchored on `opName`; `options` selects
/// which subviews to promote and how.
mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgMarker marker, PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context), marker(marker),
      options(options) {}
274
/// Promote the subviews of the matched op according to `options`. Unlike the
/// other patterns in this file, the op is checked via
/// `promoteSubviewsPrecondition` rather than a LinalgOp cast up front.
LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(marker.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use root update here. This pattern is creating other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't seem
  // to be happening here. So to fail properly, we should be cloning the op and
  // deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    // Emit a hard error (not just a match failure): helper ops may already
    // have been created, see TODO above.
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  marker.replaceLinalgMarker(rewriter, op);
  return success();
}
296
/// Construct a vectorization pattern anchored on `opName`; only ops with a
/// matching transform marker are rewritten.
mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgMarker marker,
    PatternBenefit benefit)
    : RewritePattern(opName, {}, benefit, context), marker(marker) {}
301
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const302 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
303 Operation *op, PatternRewriter &rewriter) const {
304 LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
305 if (!linalgOp)
306 return failure();
307 if (failed(marker.checkAndNotify(rewriter, linalgOp)))
308 return failure();
309 if (failed(vectorizeLinalgOpPrecondition(op)))
310 return failure();
311 vectorizeLinalgOp(rewriter, op);
312 rewriter.eraseOp(op);
313 return success();
314 }
315
applyStagedPatterns(Operation * op,ArrayRef<FrozenRewritePatternList> stage1Patterns,const FrozenRewritePatternList & stage2Patterns,function_ref<LogicalResult (Operation *)> stage3Lambda)316 LogicalResult mlir::linalg::applyStagedPatterns(
317 Operation *op, ArrayRef<FrozenRewritePatternList> stage1Patterns,
318 const FrozenRewritePatternList &stage2Patterns,
319 function_ref<LogicalResult(Operation *)> stage3Lambda) {
320 unsigned iteration = 0;
321 (void)iteration;
322 for (const auto &patterns : stage1Patterns) {
323 LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
324 << *op);
325 if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
326 LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
327 return failure();
328 }
329 LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
330 << *op);
331 if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
332 LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
333 return failure();
334 }
335 LLVM_DEBUG(DBGS() << "After 2nd stage, iter : " << iteration << "\n"
336 << *op);
337 if (stage3Lambda) {
338 if (failed(stage3Lambda(op)))
339 return failure();
340 LLVM_DEBUG(DBGS() << "After 3rd stage, iter : " << iteration << "\n"
341 << *op);
342 }
343 }
344 return success();
345 }
346
/// Traverse `e` and return an AffineExpr where all occurrences of `dim` have
/// been replaced by either:
///   - `min` if `positivePath` is true when we reach an occurrence of `dim`
///   - `max` if `positivePath` is false when we reach an occurrence of `dim`
/// `positivePath` is negated each time we hit a multiplicative or divisive
/// binary op with a constant negative coefficient.
static AffineExpr substWithMin(AffineExpr e, AffineExpr dim, AffineExpr min,
                               AffineExpr max, bool positivePath = true) {
  if (e == dim)
    return positivePath ? min : max;
  if (auto bin = e.dyn_cast<AffineBinaryOpExpr>()) {
    AffineExpr lhs = bin.getLHS();
    AffineExpr rhs = bin.getRHS();
    // Addition preserves the sign of `dim`'s coefficient on both sides:
    // recurse with the same path.
    if (bin.getKind() == mlir::AffineExprKind::Add)
      return substWithMin(lhs, dim, min, max, positivePath) +
             substWithMin(rhs, dim, min, max, positivePath);

    // For mul/div-like ops, a negative constant operand flips the sign of the
    // other operand's contribution, so recurse with the path negated.
    auto c1 = bin.getLHS().dyn_cast<AffineConstantExpr>();
    auto c2 = bin.getRHS().dyn_cast<AffineConstantExpr>();
    if (c1 && c1.getValue() < 0)
      return getAffineBinaryOpExpr(
          bin.getKind(), c1, substWithMin(rhs, dim, min, max, !positivePath));
    if (c2 && c2.getValue() < 0)
      return getAffineBinaryOpExpr(
          bin.getKind(), substWithMin(lhs, dim, min, max, !positivePath), c2);
    // No negative constant operand: keep the current path on both sides.
    return getAffineBinaryOpExpr(
        bin.getKind(), substWithMin(lhs, dim, min, max, positivePath),
        substWithMin(rhs, dim, min, max, positivePath));
  }
  // Leaf expression that is not `dim`: return unchanged.
  return e;
}
378
379 /// Given the `lbVal`, `ubVal` and `stepVal` of a loop, append `lbVal` and
380 /// `ubVal` to `dims` and `stepVal` to `symbols`.
381 /// Create new AffineDimExpr (`%lb` and `%ub`) and AffineSymbolExpr (`%step`)
382 /// with positions matching the newly appended values. Substitute occurrences of
383 /// `dimExpr` by either the min expression (i.e. `%lb`) or the max expression
384 /// (i.e. `%lb + %step * floordiv(%ub -1 - %lb, %step)`), depending on whether
385 /// the induction variable is used with a positive or negative coefficient.
substituteLoopInExpr(AffineExpr expr,AffineExpr dimExpr,Value lbVal,Value ubVal,Value stepVal,SmallVectorImpl<Value> & dims,SmallVectorImpl<Value> & symbols)386 static AffineExpr substituteLoopInExpr(AffineExpr expr, AffineExpr dimExpr,
387 Value lbVal, Value ubVal, Value stepVal,
388 SmallVectorImpl<Value> &dims,
389 SmallVectorImpl<Value> &symbols) {
390 MLIRContext *ctx = lbVal.getContext();
391 AffineExpr lb = getAffineDimExpr(dims.size(), ctx);
392 dims.push_back(lbVal);
393 AffineExpr ub = getAffineDimExpr(dims.size(), ctx);
394 dims.push_back(ubVal);
395 AffineExpr step = getAffineSymbolExpr(symbols.size(), ctx);
396 symbols.push_back(stepVal);
397 LLVM_DEBUG(DBGS() << "Before: " << expr << "\n");
398 AffineExpr ee = substWithMin(expr, dimExpr, lb,
399 lb + step * ((ub - 1) - lb).floorDiv(step));
400 LLVM_DEBUG(DBGS() << "After: " << expr << "\n");
401 return ee;
402 }
403
/// Traverse the `dims` and substitute known min or max expressions in place of
/// induction variables in `exprs`.
/// `dims` and `symbols` are updated in place as loop bounds/steps are pulled
/// in and as simplification canonicalizes the operand lists.
static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
                            SmallVectorImpl<Value> &symbols) {
  auto exprs = llvm::to_vector<4>(map.getResults());
  for (AffineExpr &expr : exprs) {
    // Iterate to a fixed point: substituting one IV can introduce new dims
    // (loop bounds) that are themselves IVs of enclosing loops.
    bool substituted = true;
    while (substituted) {
      substituted = false;
      for (unsigned dimIdx = 0; dimIdx < dims.size(); ++dimIdx) {
        Value dim = dims[dimIdx];
        AffineExpr dimExpr = getAffineDimExpr(dimIdx, expr.getContext());
        LLVM_DEBUG(DBGS() << "Subst: " << dim << " @ " << dimExpr << "\n");
        AffineExpr substitutedExpr;
        // scf.for IV: substitute using the loop's scalar bounds and step.
        if (auto forOp = scf::getForInductionVarOwner(dim))
          substitutedExpr = substituteLoopInExpr(
              expr, dimExpr, forOp.lowerBound(), forOp.upperBound(),
              forOp.step(), dims, symbols);

        // scf.parallel IV: iterate the loop dimensions. NOTE(review): the
        // same `dimExpr` is substituted for every dimension and only the
        // last result is kept -- confirm this matches the intended semantics.
        if (auto parallelForOp = scf::getParallelForInductionVarOwner(dim))
          for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e;
               ++idx)
            substitutedExpr = substituteLoopInExpr(
                expr, dimExpr, parallelForOp.lowerBound()[idx],
                parallelForOp.upperBound()[idx], parallelForOp.step()[idx],
                dims, symbols);

        // `dim` was not an scf.for / scf.parallel IV: nothing to substitute.
        if (!substitutedExpr)
          continue;

        substituted = (substitutedExpr != expr);
        expr = substitutedExpr;
      }
    }

    // Cleanup and simplify the results.
    // This needs to happen outside of the loop iterating on dims.size() since
    // it modifies dims.
    // Note: this local `map` deliberately shadows the function parameter.
    SmallVector<Value, 4> operands(dims.begin(), dims.end());
    operands.append(symbols.begin(), symbols.end());
    auto map = AffineMap::get(dims.size(), symbols.size(), exprs,
                              exprs.front().getContext());

    LLVM_DEBUG(DBGS() << "Map to simplify: " << map << "\n");

    // Pull in affine.apply operations and compose them fully into the
    // result.
    fullyComposeAffineMapAndOperands(&map, &operands);
    canonicalizeMapAndOperands(&map, &operands);
    map = simplifyAffineMap(map);
    // Assign the results back; canonicalization may have dropped or reordered
    // operands, so resplit them into dims and symbols.
    exprs.assign(map.getResults().begin(), map.getResults().end());
    dims.assign(operands.begin(), operands.begin() + map.getNumDims());
    symbols.assign(operands.begin() + map.getNumDims(), operands.end());

    LLVM_DEBUG(DBGS() << "Map simplified: " << map << "\n");
  }

  assert(!exprs.empty() && "Unexpected empty exprs");
  return AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
}
465
/// Canonicalize an affine.min whose operands involve scf loop IVs: substitute
/// each IV by its min/max bound expressions, then look for a result expression
/// that is provably the minimum and replace the affine.min with it.
LogicalResult AffineMinSCFCanonicalizationPattern::matchAndRewrite(
    AffineMinOp minOp, PatternRewriter &rewriter) const {
  LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation()
                    << "\n");

  // `dims`/`symbols` are mutated by `substitute` as loop bounds are pulled in.
  SmallVector<Value, 4> dims(minOp.getDimOperands()),
      symbols(minOp.getSymbolOperands());
  AffineMap map = substitute(minOp.getAffineMap(), dims, symbols);

  LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n");

  // Check whether any of the expressions, when subtracted from all other
  // expressions, produces only >= 0 constants. If so, it is the min.
  for (auto e : minOp.getAffineMap().getResults()) {
    LLVM_DEBUG(DBGS() << "Candidate min: " << e << "\n");
    // Only symbolic-or-constant candidates can be rebuilt as a standalone
    // replacement expression below.
    if (!e.isSymbolicOrConstant())
      continue;

    // Returns true when `e` may be negative: negative constants, and
    // (conservatively) any non-constant expression. NOTE(review): despite
    // the name, a zero constant yields false (treated as nonnegative).
    auto isNonPositive = [](AffineExpr e) {
      if (auto cst = e.dyn_cast<AffineConstantExpr>())
        return cst.getValue() < 0;
      return true;
    };

    // Build the subMap and check everything is statically known to be
    // nonnegative.
    SmallVector<AffineExpr, 4> subExprs;
    subExprs.reserve(map.getNumResults());
    for (auto ee : map.getResults())
      subExprs.push_back(ee - e);
    MLIRContext *ctx = minOp.getContext();
    AffineMap subMap = simplifyAffineMap(
        AffineMap::get(map.getNumDims(), map.getNumSymbols(), subExprs, ctx));
    LLVM_DEBUG(DBGS() << "simplified subMap: " << subMap << "\n");
    if (llvm::any_of(subMap.getResults(), isNonPositive))
      continue;

    // Static min found.
    if (auto cst = e.dyn_cast<AffineConstantExpr>()) {
      // Constant min: replace directly with a constant index.
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, cst.getValue());
    } else {
      // Symbolic min: rebuild it as an affine.apply over the (canonicalized)
      // operands.
      auto resultMap = AffineMap::get(0, map.getNumSymbols(), {e}, ctx);
      SmallVector<Value, 4> resultOperands = dims;
      resultOperands.append(symbols.begin(), symbols.end());
      canonicalizeMapAndOperands(&resultMap, &resultOperands);
      resultMap = simplifyAffineMap(resultMap);
      rewriter.replaceOpWithNewOp<AffineApplyOp>(minOp, resultMap,
                                                 resultOperands);
    }
    return success();
  }

  return failure();
}
520