1 //===- AffineToStandard.cpp - Lower affine constructs to primitives -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file lowers affine constructs (affine.for, affine.if, affine.parallel,
// affine.apply, affine.min/max, and affine memory/DMA accesses) within a
// function into equivalent SCF, standard, and vector dialect operations.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
15
16 #include "../PassDetail.h"
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/SCF/SCF.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/Dialect/Vector/VectorOps.h"
21 #include "mlir/IR/AffineExprVisitor.h"
22 #include "mlir/IR/BlockAndValueMapping.h"
23 #include "mlir/IR/Builders.h"
24 #include "mlir/IR/IntegerSet.h"
25 #include "mlir/IR/MLIRContext.h"
26 #include "mlir/Pass/Pass.h"
27 #include "mlir/Transforms/DialectConversion.h"
28 #include "mlir/Transforms/Passes.h"
29
30 using namespace mlir;
31 using namespace mlir::vector;
32
33 namespace {
34 /// Visit affine expressions recursively and build the sequence of operations
35 /// that correspond to it. Visitation functions return an Value of the
36 /// expression subtree they visited or `nullptr` on error.
37 class AffineApplyExpander
38 : public AffineExprVisitor<AffineApplyExpander, Value> {
39 public:
40 /// This internal class expects arguments to be non-null, checks must be
41 /// performed at the call site.
AffineApplyExpander(OpBuilder & builder,ValueRange dimValues,ValueRange symbolValues,Location loc)42 AffineApplyExpander(OpBuilder &builder, ValueRange dimValues,
43 ValueRange symbolValues, Location loc)
44 : builder(builder), dimValues(dimValues), symbolValues(symbolValues),
45 loc(loc) {}
46
buildBinaryExpr(AffineBinaryOpExpr expr)47 template <typename OpTy> Value buildBinaryExpr(AffineBinaryOpExpr expr) {
48 auto lhs = visit(expr.getLHS());
49 auto rhs = visit(expr.getRHS());
50 if (!lhs || !rhs)
51 return nullptr;
52 auto op = builder.create<OpTy>(loc, lhs, rhs);
53 return op.getResult();
54 }
55
visitAddExpr(AffineBinaryOpExpr expr)56 Value visitAddExpr(AffineBinaryOpExpr expr) {
57 return buildBinaryExpr<AddIOp>(expr);
58 }
59
visitMulExpr(AffineBinaryOpExpr expr)60 Value visitMulExpr(AffineBinaryOpExpr expr) {
61 return buildBinaryExpr<MulIOp>(expr);
62 }
63
64 /// Euclidean modulo operation: negative RHS is not allowed.
65 /// Remainder of the euclidean integer division is always non-negative.
66 ///
67 /// Implemented as
68 ///
69 /// a mod b =
70 /// let remainder = srem a, b;
71 /// negative = a < 0 in
72 /// select negative, remainder + b, remainder.
visitModExpr(AffineBinaryOpExpr expr)73 Value visitModExpr(AffineBinaryOpExpr expr) {
74 auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
75 if (!rhsConst) {
76 emitError(
77 loc,
78 "semi-affine expressions (modulo by non-const) are not supported");
79 return nullptr;
80 }
81 if (rhsConst.getValue() <= 0) {
82 emitError(loc, "modulo by non-positive value is not supported");
83 return nullptr;
84 }
85
86 auto lhs = visit(expr.getLHS());
87 auto rhs = visit(expr.getRHS());
88 assert(lhs && rhs && "unexpected affine expr lowering failure");
89
90 Value remainder = builder.create<SignedRemIOp>(loc, lhs, rhs);
91 Value zeroCst = builder.create<ConstantIndexOp>(loc, 0);
92 Value isRemainderNegative =
93 builder.create<CmpIOp>(loc, CmpIPredicate::slt, remainder, zeroCst);
94 Value correctedRemainder = builder.create<AddIOp>(loc, remainder, rhs);
95 Value result = builder.create<SelectOp>(loc, isRemainderNegative,
96 correctedRemainder, remainder);
97 return result;
98 }
99
100 /// Floor division operation (rounds towards negative infinity).
101 ///
102 /// For positive divisors, it can be implemented without branching and with a
103 /// single division operation as
104 ///
105 /// a floordiv b =
106 /// let negative = a < 0 in
107 /// let absolute = negative ? -a - 1 : a in
108 /// let quotient = absolute / b in
109 /// negative ? -quotient - 1 : quotient
visitFloorDivExpr(AffineBinaryOpExpr expr)110 Value visitFloorDivExpr(AffineBinaryOpExpr expr) {
111 auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
112 if (!rhsConst) {
113 emitError(
114 loc,
115 "semi-affine expressions (division by non-const) are not supported");
116 return nullptr;
117 }
118 if (rhsConst.getValue() <= 0) {
119 emitError(loc, "division by non-positive value is not supported");
120 return nullptr;
121 }
122
123 auto lhs = visit(expr.getLHS());
124 auto rhs = visit(expr.getRHS());
125 assert(lhs && rhs && "unexpected affine expr lowering failure");
126
127 Value zeroCst = builder.create<ConstantIndexOp>(loc, 0);
128 Value noneCst = builder.create<ConstantIndexOp>(loc, -1);
129 Value negative =
130 builder.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, zeroCst);
131 Value negatedDecremented = builder.create<SubIOp>(loc, noneCst, lhs);
132 Value dividend =
133 builder.create<SelectOp>(loc, negative, negatedDecremented, lhs);
134 Value quotient = builder.create<SignedDivIOp>(loc, dividend, rhs);
135 Value correctedQuotient = builder.create<SubIOp>(loc, noneCst, quotient);
136 Value result =
137 builder.create<SelectOp>(loc, negative, correctedQuotient, quotient);
138 return result;
139 }
140
141 /// Ceiling division operation (rounds towards positive infinity).
142 ///
143 /// For positive divisors, it can be implemented without branching and with a
144 /// single division operation as
145 ///
146 /// a ceildiv b =
147 /// let negative = a <= 0 in
148 /// let absolute = negative ? -a : a - 1 in
149 /// let quotient = absolute / b in
150 /// negative ? -quotient : quotient + 1
visitCeilDivExpr(AffineBinaryOpExpr expr)151 Value visitCeilDivExpr(AffineBinaryOpExpr expr) {
152 auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
153 if (!rhsConst) {
154 emitError(loc) << "semi-affine expressions (division by non-const) are "
155 "not supported";
156 return nullptr;
157 }
158 if (rhsConst.getValue() <= 0) {
159 emitError(loc, "division by non-positive value is not supported");
160 return nullptr;
161 }
162 auto lhs = visit(expr.getLHS());
163 auto rhs = visit(expr.getRHS());
164 assert(lhs && rhs && "unexpected affine expr lowering failure");
165
166 Value zeroCst = builder.create<ConstantIndexOp>(loc, 0);
167 Value oneCst = builder.create<ConstantIndexOp>(loc, 1);
168 Value nonPositive =
169 builder.create<CmpIOp>(loc, CmpIPredicate::sle, lhs, zeroCst);
170 Value negated = builder.create<SubIOp>(loc, zeroCst, lhs);
171 Value decremented = builder.create<SubIOp>(loc, lhs, oneCst);
172 Value dividend =
173 builder.create<SelectOp>(loc, nonPositive, negated, decremented);
174 Value quotient = builder.create<SignedDivIOp>(loc, dividend, rhs);
175 Value negatedQuotient = builder.create<SubIOp>(loc, zeroCst, quotient);
176 Value incrementedQuotient = builder.create<AddIOp>(loc, quotient, oneCst);
177 Value result = builder.create<SelectOp>(loc, nonPositive, negatedQuotient,
178 incrementedQuotient);
179 return result;
180 }
181
visitConstantExpr(AffineConstantExpr expr)182 Value visitConstantExpr(AffineConstantExpr expr) {
183 auto valueAttr =
184 builder.getIntegerAttr(builder.getIndexType(), expr.getValue());
185 auto op =
186 builder.create<ConstantOp>(loc, builder.getIndexType(), valueAttr);
187 return op.getResult();
188 }
189
visitDimExpr(AffineDimExpr expr)190 Value visitDimExpr(AffineDimExpr expr) {
191 assert(expr.getPosition() < dimValues.size() &&
192 "affine dim position out of range");
193 return dimValues[expr.getPosition()];
194 }
195
visitSymbolExpr(AffineSymbolExpr expr)196 Value visitSymbolExpr(AffineSymbolExpr expr) {
197 assert(expr.getPosition() < symbolValues.size() &&
198 "symbol dim position out of range");
199 return symbolValues[expr.getPosition()];
200 }
201
202 private:
203 OpBuilder &builder;
204 ValueRange dimValues;
205 ValueRange symbolValues;
206
207 Location loc;
208 };
209 } // namespace
210
211 /// Create a sequence of operations that implement the `expr` applied to the
212 /// given dimension and symbol values.
expandAffineExpr(OpBuilder & builder,Location loc,AffineExpr expr,ValueRange dimValues,ValueRange symbolValues)213 mlir::Value mlir::expandAffineExpr(OpBuilder &builder, Location loc,
214 AffineExpr expr, ValueRange dimValues,
215 ValueRange symbolValues) {
216 return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr);
217 }
218
219 /// Create a sequence of operations that implement the `affineMap` applied to
220 /// the given `operands` (as it it were an AffineApplyOp).
expandAffineMap(OpBuilder & builder,Location loc,AffineMap affineMap,ValueRange operands)221 Optional<SmallVector<Value, 8>> mlir::expandAffineMap(OpBuilder &builder,
222 Location loc,
223 AffineMap affineMap,
224 ValueRange operands) {
225 auto numDims = affineMap.getNumDims();
226 auto expanded = llvm::to_vector<8>(
227 llvm::map_range(affineMap.getResults(),
228 [numDims, &builder, loc, operands](AffineExpr expr) {
229 return expandAffineExpr(builder, loc, expr,
230 operands.take_front(numDims),
231 operands.drop_front(numDims));
232 }));
233 if (llvm::all_of(expanded, [](Value v) { return v; }))
234 return expanded;
235 return None;
236 }
237
238 /// Given a range of values, emit the code that reduces them with "min" or "max"
239 /// depending on the provided comparison predicate. The predicate defines which
240 /// comparison to perform, "lt" for "min", "gt" for "max" and is used for the
241 /// `cmpi` operation followed by the `select` operation:
242 ///
243 /// %cond = cmpi "predicate" %v0, %v1
244 /// %result = select %cond, %v0, %v1
245 ///
246 /// Multiple values are scanned in a linear sequence. This creates a data
247 /// dependences that wouldn't exist in a tree reduction, but is easier to
248 /// recognize as a reduction by the subsequent passes.
buildMinMaxReductionSeq(Location loc,CmpIPredicate predicate,ValueRange values,OpBuilder & builder)249 static Value buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate,
250 ValueRange values, OpBuilder &builder) {
251 assert(!llvm::empty(values) && "empty min/max chain");
252
253 auto valueIt = values.begin();
254 Value value = *valueIt++;
255 for (; valueIt != values.end(); ++valueIt) {
256 auto cmpOp = builder.create<CmpIOp>(loc, predicate, value, *valueIt);
257 value = builder.create<SelectOp>(loc, cmpOp.getResult(), value, *valueIt);
258 }
259
260 return value;
261 }
262
263 /// Emit instructions that correspond to computing the maximum value among the
264 /// values of a (potentially) multi-output affine map applied to `operands`.
lowerAffineMapMax(OpBuilder & builder,Location loc,AffineMap map,ValueRange operands)265 static Value lowerAffineMapMax(OpBuilder &builder, Location loc, AffineMap map,
266 ValueRange operands) {
267 if (auto values = expandAffineMap(builder, loc, map, operands))
268 return buildMinMaxReductionSeq(loc, CmpIPredicate::sgt, *values, builder);
269 return nullptr;
270 }
271
272 /// Emit instructions that correspond to computing the minimum value among the
273 /// values of a (potentially) multi-output affine map applied to `operands`.
lowerAffineMapMin(OpBuilder & builder,Location loc,AffineMap map,ValueRange operands)274 static Value lowerAffineMapMin(OpBuilder &builder, Location loc, AffineMap map,
275 ValueRange operands) {
276 if (auto values = expandAffineMap(builder, loc, map, operands))
277 return buildMinMaxReductionSeq(loc, CmpIPredicate::slt, *values, builder);
278 return nullptr;
279 }
280
281 /// Emit instructions that correspond to the affine map in the upper bound
282 /// applied to the respective operands, and compute the minimum value across
283 /// the results.
lowerAffineUpperBound(AffineForOp op,OpBuilder & builder)284 Value mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) {
285 return lowerAffineMapMin(builder, op.getLoc(), op.getUpperBoundMap(),
286 op.getUpperBoundOperands());
287 }
288
289 /// Emit instructions that correspond to the affine map in the lower bound
290 /// applied to the respective operands, and compute the maximum value across
291 /// the results.
lowerAffineLowerBound(AffineForOp op,OpBuilder & builder)292 Value mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) {
293 return lowerAffineMapMax(builder, op.getLoc(), op.getLowerBoundMap(),
294 op.getLowerBoundOperands());
295 }
296
297 namespace {
298 class AffineMinLowering : public OpRewritePattern<AffineMinOp> {
299 public:
300 using OpRewritePattern<AffineMinOp>::OpRewritePattern;
301
matchAndRewrite(AffineMinOp op,PatternRewriter & rewriter) const302 LogicalResult matchAndRewrite(AffineMinOp op,
303 PatternRewriter &rewriter) const override {
304 Value reduced =
305 lowerAffineMapMin(rewriter, op.getLoc(), op.map(), op.operands());
306 if (!reduced)
307 return failure();
308
309 rewriter.replaceOp(op, reduced);
310 return success();
311 }
312 };
313
314 class AffineMaxLowering : public OpRewritePattern<AffineMaxOp> {
315 public:
316 using OpRewritePattern<AffineMaxOp>::OpRewritePattern;
317
matchAndRewrite(AffineMaxOp op,PatternRewriter & rewriter) const318 LogicalResult matchAndRewrite(AffineMaxOp op,
319 PatternRewriter &rewriter) const override {
320 Value reduced =
321 lowerAffineMapMax(rewriter, op.getLoc(), op.map(), op.operands());
322 if (!reduced)
323 return failure();
324
325 rewriter.replaceOp(op, reduced);
326 return success();
327 }
328 };
329
330 /// Affine yields ops are removed.
331 class AffineYieldOpLowering : public OpRewritePattern<AffineYieldOp> {
332 public:
333 using OpRewritePattern<AffineYieldOp>::OpRewritePattern;
334
matchAndRewrite(AffineYieldOp op,PatternRewriter & rewriter) const335 LogicalResult matchAndRewrite(AffineYieldOp op,
336 PatternRewriter &rewriter) const override {
337 rewriter.replaceOpWithNewOp<scf::YieldOp>(op);
338 return success();
339 }
340 };
341
342 class AffineForLowering : public OpRewritePattern<AffineForOp> {
343 public:
344 using OpRewritePattern<AffineForOp>::OpRewritePattern;
345
matchAndRewrite(AffineForOp op,PatternRewriter & rewriter) const346 LogicalResult matchAndRewrite(AffineForOp op,
347 PatternRewriter &rewriter) const override {
348 Location loc = op.getLoc();
349 Value lowerBound = lowerAffineLowerBound(op, rewriter);
350 Value upperBound = lowerAffineUpperBound(op, rewriter);
351 Value step = rewriter.create<ConstantIndexOp>(loc, op.getStep());
352 auto f = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
353 rewriter.eraseBlock(f.getBody());
354 rewriter.inlineRegionBefore(op.region(), f.region(), f.region().end());
355 rewriter.eraseOp(op);
356 return success();
357 }
358 };
359
360 /// Convert an `affine.parallel` (loop nest) operation into a `scf.parallel`
361 /// operation.
362 class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
363 public:
364 using OpRewritePattern<AffineParallelOp>::OpRewritePattern;
365
matchAndRewrite(AffineParallelOp op,PatternRewriter & rewriter) const366 LogicalResult matchAndRewrite(AffineParallelOp op,
367 PatternRewriter &rewriter) const override {
368 Location loc = op.getLoc();
369 SmallVector<Value, 8> steps;
370 SmallVector<Value, 8> upperBoundTuple;
371 SmallVector<Value, 8> lowerBoundTuple;
372 // Finding lower and upper bound by expanding the map expression.
373 // Checking if expandAffineMap is not giving NULL.
374 Optional<SmallVector<Value, 8>> upperBound = expandAffineMap(
375 rewriter, loc, op.upperBoundsMap(), op.getUpperBoundsOperands());
376 Optional<SmallVector<Value, 8>> lowerBound = expandAffineMap(
377 rewriter, loc, op.lowerBoundsMap(), op.getLowerBoundsOperands());
378 if (!lowerBound || !upperBound)
379 return failure();
380 upperBoundTuple = *upperBound;
381 lowerBoundTuple = *lowerBound;
382 steps.reserve(op.steps().size());
383 for (Attribute step : op.steps())
384 steps.push_back(rewriter.create<ConstantIndexOp>(
385 loc, step.cast<IntegerAttr>().getInt()));
386 // Creating empty scf.parallel op body with appropriate bounds.
387 auto parallelOp = rewriter.create<scf::ParallelOp>(loc, lowerBoundTuple,
388 upperBoundTuple, steps);
389 rewriter.eraseBlock(parallelOp.getBody());
390 rewriter.inlineRegionBefore(op.region(), parallelOp.region(),
391 parallelOp.region().end());
392 rewriter.eraseOp(op);
393 return success();
394 }
395 };
396
397 class AffineIfLowering : public OpRewritePattern<AffineIfOp> {
398 public:
399 using OpRewritePattern<AffineIfOp>::OpRewritePattern;
400
matchAndRewrite(AffineIfOp op,PatternRewriter & rewriter) const401 LogicalResult matchAndRewrite(AffineIfOp op,
402 PatternRewriter &rewriter) const override {
403 auto loc = op.getLoc();
404
405 // Now we just have to handle the condition logic.
406 auto integerSet = op.getIntegerSet();
407 Value zeroConstant = rewriter.create<ConstantIndexOp>(loc, 0);
408 SmallVector<Value, 8> operands(op.getOperands());
409 auto operandsRef = llvm::makeArrayRef(operands);
410
411 // Calculate cond as a conjunction without short-circuiting.
412 Value cond = nullptr;
413 for (unsigned i = 0, e = integerSet.getNumConstraints(); i < e; ++i) {
414 AffineExpr constraintExpr = integerSet.getConstraint(i);
415 bool isEquality = integerSet.isEq(i);
416
417 // Build and apply an affine expression
418 auto numDims = integerSet.getNumDims();
419 Value affResult = expandAffineExpr(rewriter, loc, constraintExpr,
420 operandsRef.take_front(numDims),
421 operandsRef.drop_front(numDims));
422 if (!affResult)
423 return failure();
424 auto pred = isEquality ? CmpIPredicate::eq : CmpIPredicate::sge;
425 Value cmpVal =
426 rewriter.create<CmpIOp>(loc, pred, affResult, zeroConstant);
427 cond =
428 cond ? rewriter.create<AndOp>(loc, cond, cmpVal).getResult() : cmpVal;
429 }
430 cond = cond ? cond
431 : rewriter.create<ConstantIntOp>(loc, /*value=*/1, /*width=*/1);
432
433 bool hasElseRegion = !op.elseRegion().empty();
434 auto ifOp = rewriter.create<scf::IfOp>(loc, cond, hasElseRegion);
435 rewriter.inlineRegionBefore(op.thenRegion(), &ifOp.thenRegion().back());
436 rewriter.eraseBlock(&ifOp.thenRegion().back());
437 if (hasElseRegion) {
438 rewriter.inlineRegionBefore(op.elseRegion(), &ifOp.elseRegion().back());
439 rewriter.eraseBlock(&ifOp.elseRegion().back());
440 }
441
442 // Ok, we're done!
443 rewriter.eraseOp(op);
444 return success();
445 }
446 };
447
448 /// Convert an "affine.apply" operation into a sequence of arithmetic
449 /// operations using the StandardOps dialect.
450 class AffineApplyLowering : public OpRewritePattern<AffineApplyOp> {
451 public:
452 using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
453
matchAndRewrite(AffineApplyOp op,PatternRewriter & rewriter) const454 LogicalResult matchAndRewrite(AffineApplyOp op,
455 PatternRewriter &rewriter) const override {
456 auto maybeExpandedMap =
457 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
458 llvm::to_vector<8>(op.getOperands()));
459 if (!maybeExpandedMap)
460 return failure();
461 rewriter.replaceOp(op, *maybeExpandedMap);
462 return success();
463 }
464 };
465
466 /// Apply the affine map from an 'affine.load' operation to its operands, and
467 /// feed the results to a newly created 'std.load' operation (which replaces the
468 /// original 'affine.load').
469 class AffineLoadLowering : public OpRewritePattern<AffineLoadOp> {
470 public:
471 using OpRewritePattern<AffineLoadOp>::OpRewritePattern;
472
matchAndRewrite(AffineLoadOp op,PatternRewriter & rewriter) const473 LogicalResult matchAndRewrite(AffineLoadOp op,
474 PatternRewriter &rewriter) const override {
475 // Expand affine map from 'affineLoadOp'.
476 SmallVector<Value, 8> indices(op.getMapOperands());
477 auto resultOperands =
478 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
479 if (!resultOperands)
480 return failure();
481
482 // Build std.load memref[expandedMap.results].
483 rewriter.replaceOpWithNewOp<LoadOp>(op, op.getMemRef(), *resultOperands);
484 return success();
485 }
486 };
487
488 /// Apply the affine map from an 'affine.prefetch' operation to its operands,
489 /// and feed the results to a newly created 'std.prefetch' operation (which
490 /// replaces the original 'affine.prefetch').
491 class AffinePrefetchLowering : public OpRewritePattern<AffinePrefetchOp> {
492 public:
493 using OpRewritePattern<AffinePrefetchOp>::OpRewritePattern;
494
matchAndRewrite(AffinePrefetchOp op,PatternRewriter & rewriter) const495 LogicalResult matchAndRewrite(AffinePrefetchOp op,
496 PatternRewriter &rewriter) const override {
497 // Expand affine map from 'affinePrefetchOp'.
498 SmallVector<Value, 8> indices(op.getMapOperands());
499 auto resultOperands =
500 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
501 if (!resultOperands)
502 return failure();
503
504 // Build std.prefetch memref[expandedMap.results].
505 rewriter.replaceOpWithNewOp<PrefetchOp>(op, op.memref(), *resultOperands,
506 op.isWrite(), op.localityHint(),
507 op.isDataCache());
508 return success();
509 }
510 };
511
512 /// Apply the affine map from an 'affine.store' operation to its operands, and
513 /// feed the results to a newly created 'std.store' operation (which replaces
514 /// the original 'affine.store').
515 class AffineStoreLowering : public OpRewritePattern<AffineStoreOp> {
516 public:
517 using OpRewritePattern<AffineStoreOp>::OpRewritePattern;
518
matchAndRewrite(AffineStoreOp op,PatternRewriter & rewriter) const519 LogicalResult matchAndRewrite(AffineStoreOp op,
520 PatternRewriter &rewriter) const override {
521 // Expand affine map from 'affineStoreOp'.
522 SmallVector<Value, 8> indices(op.getMapOperands());
523 auto maybeExpandedMap =
524 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
525 if (!maybeExpandedMap)
526 return failure();
527
528 // Build std.store valueToStore, memref[expandedMap.results].
529 rewriter.replaceOpWithNewOp<StoreOp>(op, op.getValueToStore(),
530 op.getMemRef(), *maybeExpandedMap);
531 return success();
532 }
533 };
534
535 /// Apply the affine maps from an 'affine.dma_start' operation to each of their
536 /// respective map operands, and feed the results to a newly created
537 /// 'std.dma_start' operation (which replaces the original 'affine.dma_start').
538 class AffineDmaStartLowering : public OpRewritePattern<AffineDmaStartOp> {
539 public:
540 using OpRewritePattern<AffineDmaStartOp>::OpRewritePattern;
541
matchAndRewrite(AffineDmaStartOp op,PatternRewriter & rewriter) const542 LogicalResult matchAndRewrite(AffineDmaStartOp op,
543 PatternRewriter &rewriter) const override {
544 SmallVector<Value, 8> operands(op.getOperands());
545 auto operandsRef = llvm::makeArrayRef(operands);
546
547 // Expand affine map for DMA source memref.
548 auto maybeExpandedSrcMap = expandAffineMap(
549 rewriter, op.getLoc(), op.getSrcMap(),
550 operandsRef.drop_front(op.getSrcMemRefOperandIndex() + 1));
551 if (!maybeExpandedSrcMap)
552 return failure();
553 // Expand affine map for DMA destination memref.
554 auto maybeExpandedDstMap = expandAffineMap(
555 rewriter, op.getLoc(), op.getDstMap(),
556 operandsRef.drop_front(op.getDstMemRefOperandIndex() + 1));
557 if (!maybeExpandedDstMap)
558 return failure();
559 // Expand affine map for DMA tag memref.
560 auto maybeExpandedTagMap = expandAffineMap(
561 rewriter, op.getLoc(), op.getTagMap(),
562 operandsRef.drop_front(op.getTagMemRefOperandIndex() + 1));
563 if (!maybeExpandedTagMap)
564 return failure();
565
566 // Build std.dma_start operation with affine map results.
567 rewriter.replaceOpWithNewOp<DmaStartOp>(
568 op, op.getSrcMemRef(), *maybeExpandedSrcMap, op.getDstMemRef(),
569 *maybeExpandedDstMap, op.getNumElements(), op.getTagMemRef(),
570 *maybeExpandedTagMap, op.getStride(), op.getNumElementsPerStride());
571 return success();
572 }
573 };
574
575 /// Apply the affine map from an 'affine.dma_wait' operation tag memref,
576 /// and feed the results to a newly created 'std.dma_wait' operation (which
577 /// replaces the original 'affine.dma_wait').
578 class AffineDmaWaitLowering : public OpRewritePattern<AffineDmaWaitOp> {
579 public:
580 using OpRewritePattern<AffineDmaWaitOp>::OpRewritePattern;
581
matchAndRewrite(AffineDmaWaitOp op,PatternRewriter & rewriter) const582 LogicalResult matchAndRewrite(AffineDmaWaitOp op,
583 PatternRewriter &rewriter) const override {
584 // Expand affine map for DMA tag memref.
585 SmallVector<Value, 8> indices(op.getTagIndices());
586 auto maybeExpandedTagMap =
587 expandAffineMap(rewriter, op.getLoc(), op.getTagMap(), indices);
588 if (!maybeExpandedTagMap)
589 return failure();
590
591 // Build std.dma_wait operation with affine map results.
592 rewriter.replaceOpWithNewOp<DmaWaitOp>(
593 op, op.getTagMemRef(), *maybeExpandedTagMap, op.getNumElements());
594 return success();
595 }
596 };
597
598 /// Apply the affine map from an 'affine.vector_load' operation to its operands,
599 /// and feed the results to a newly created 'vector.transfer_read' operation
600 /// (which replaces the original 'affine.vector_load').
601 class AffineVectorLoadLowering : public OpRewritePattern<AffineVectorLoadOp> {
602 public:
603 using OpRewritePattern<AffineVectorLoadOp>::OpRewritePattern;
604
matchAndRewrite(AffineVectorLoadOp op,PatternRewriter & rewriter) const605 LogicalResult matchAndRewrite(AffineVectorLoadOp op,
606 PatternRewriter &rewriter) const override {
607 // Expand affine map from 'affineVectorLoadOp'.
608 SmallVector<Value, 8> indices(op.getMapOperands());
609 auto resultOperands =
610 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
611 if (!resultOperands)
612 return failure();
613
614 // Build vector.transfer_read memref[expandedMap.results].
615 rewriter.replaceOpWithNewOp<TransferReadOp>(
616 op, op.getVectorType(), op.getMemRef(), *resultOperands);
617 return success();
618 }
619 };
620
621 /// Apply the affine map from an 'affine.vector_store' operation to its
622 /// operands, and feed the results to a newly created 'vector.transfer_write'
623 /// operation (which replaces the original 'affine.vector_store').
624 class AffineVectorStoreLowering : public OpRewritePattern<AffineVectorStoreOp> {
625 public:
626 using OpRewritePattern<AffineVectorStoreOp>::OpRewritePattern;
627
matchAndRewrite(AffineVectorStoreOp op,PatternRewriter & rewriter) const628 LogicalResult matchAndRewrite(AffineVectorStoreOp op,
629 PatternRewriter &rewriter) const override {
630 // Expand affine map from 'affineVectorStoreOp'.
631 SmallVector<Value, 8> indices(op.getMapOperands());
632 auto maybeExpandedMap =
633 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
634 if (!maybeExpandedMap)
635 return failure();
636
637 rewriter.replaceOpWithNewOp<TransferWriteOp>(
638 op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
639 return success();
640 }
641 };
642
643 } // end namespace
644
populateAffineToStdConversionPatterns(OwningRewritePatternList & patterns,MLIRContext * ctx)645 void mlir::populateAffineToStdConversionPatterns(
646 OwningRewritePatternList &patterns, MLIRContext *ctx) {
647 // clang-format off
648 patterns.insert<
649 AffineApplyLowering,
650 AffineDmaStartLowering,
651 AffineDmaWaitLowering,
652 AffineLoadLowering,
653 AffineMinLowering,
654 AffineMaxLowering,
655 AffineParallelLowering,
656 AffinePrefetchLowering,
657 AffineStoreLowering,
658 AffineForLowering,
659 AffineIfLowering,
660 AffineYieldOpLowering>(ctx);
661 // clang-format on
662 }
663
populateAffineToVectorConversionPatterns(OwningRewritePatternList & patterns,MLIRContext * ctx)664 void mlir::populateAffineToVectorConversionPatterns(
665 OwningRewritePatternList &patterns, MLIRContext *ctx) {
666 // clang-format off
667 patterns.insert<
668 AffineVectorLoadLowering,
669 AffineVectorStoreLowering>(ctx);
670 // clang-format on
671 }
672
673 namespace {
674 class LowerAffinePass : public ConvertAffineToStandardBase<LowerAffinePass> {
runOnOperation()675 void runOnOperation() override {
676 OwningRewritePatternList patterns;
677 populateAffineToStdConversionPatterns(patterns, &getContext());
678 populateAffineToVectorConversionPatterns(patterns, &getContext());
679 ConversionTarget target(getContext());
680 target
681 .addLegalDialect<scf::SCFDialect, StandardOpsDialect, VectorDialect>();
682 if (failed(applyPartialConversion(getOperation(), target,
683 std::move(patterns))))
684 signalPassFailure();
685 }
686 };
687 } // namespace
688
689 /// Lowers If and For operations within a function into their lower level CFG
690 /// equivalent blocks.
createLowerAffinePass()691 std::unique_ptr<Pass> mlir::createLowerAffinePass() {
692 return std::make_unique<LowerAffinePass>();
693 }
694