• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- AffineToStandard.cpp - Lower affine constructs to primitives -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file lowers affine constructs (If and For statements, AffineApply
10 // operations) within a function into their standard If and For equivalent ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
15 
16 #include "../PassDetail.h"
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/SCF/SCF.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/Dialect/Vector/VectorOps.h"
21 #include "mlir/IR/AffineExprVisitor.h"
22 #include "mlir/IR/BlockAndValueMapping.h"
23 #include "mlir/IR/Builders.h"
24 #include "mlir/IR/IntegerSet.h"
25 #include "mlir/IR/MLIRContext.h"
26 #include "mlir/Pass/Pass.h"
27 #include "mlir/Transforms/DialectConversion.h"
28 #include "mlir/Transforms/Passes.h"
29 
30 using namespace mlir;
31 using namespace mlir::vector;
32 
33 namespace {
34 /// Visit affine expressions recursively and build the sequence of operations
35 /// that correspond to it.  Visitation functions return an Value of the
36 /// expression subtree they visited or `nullptr` on error.
37 class AffineApplyExpander
38     : public AffineExprVisitor<AffineApplyExpander, Value> {
39 public:
40   /// This internal class expects arguments to be non-null, checks must be
41   /// performed at the call site.
AffineApplyExpander(OpBuilder & builder,ValueRange dimValues,ValueRange symbolValues,Location loc)42   AffineApplyExpander(OpBuilder &builder, ValueRange dimValues,
43                       ValueRange symbolValues, Location loc)
44       : builder(builder), dimValues(dimValues), symbolValues(symbolValues),
45         loc(loc) {}
46 
buildBinaryExpr(AffineBinaryOpExpr expr)47   template <typename OpTy> Value buildBinaryExpr(AffineBinaryOpExpr expr) {
48     auto lhs = visit(expr.getLHS());
49     auto rhs = visit(expr.getRHS());
50     if (!lhs || !rhs)
51       return nullptr;
52     auto op = builder.create<OpTy>(loc, lhs, rhs);
53     return op.getResult();
54   }
55 
visitAddExpr(AffineBinaryOpExpr expr)56   Value visitAddExpr(AffineBinaryOpExpr expr) {
57     return buildBinaryExpr<AddIOp>(expr);
58   }
59 
visitMulExpr(AffineBinaryOpExpr expr)60   Value visitMulExpr(AffineBinaryOpExpr expr) {
61     return buildBinaryExpr<MulIOp>(expr);
62   }
63 
64   /// Euclidean modulo operation: negative RHS is not allowed.
65   /// Remainder of the euclidean integer division is always non-negative.
66   ///
67   /// Implemented as
68   ///
69   ///     a mod b =
70   ///         let remainder = srem a, b;
71   ///             negative = a < 0 in
72   ///         select negative, remainder + b, remainder.
visitModExpr(AffineBinaryOpExpr expr)73   Value visitModExpr(AffineBinaryOpExpr expr) {
74     auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
75     if (!rhsConst) {
76       emitError(
77           loc,
78           "semi-affine expressions (modulo by non-const) are not supported");
79       return nullptr;
80     }
81     if (rhsConst.getValue() <= 0) {
82       emitError(loc, "modulo by non-positive value is not supported");
83       return nullptr;
84     }
85 
86     auto lhs = visit(expr.getLHS());
87     auto rhs = visit(expr.getRHS());
88     assert(lhs && rhs && "unexpected affine expr lowering failure");
89 
90     Value remainder = builder.create<SignedRemIOp>(loc, lhs, rhs);
91     Value zeroCst = builder.create<ConstantIndexOp>(loc, 0);
92     Value isRemainderNegative =
93         builder.create<CmpIOp>(loc, CmpIPredicate::slt, remainder, zeroCst);
94     Value correctedRemainder = builder.create<AddIOp>(loc, remainder, rhs);
95     Value result = builder.create<SelectOp>(loc, isRemainderNegative,
96                                             correctedRemainder, remainder);
97     return result;
98   }
99 
100   /// Floor division operation (rounds towards negative infinity).
101   ///
102   /// For positive divisors, it can be implemented without branching and with a
103   /// single division operation as
104   ///
105   ///        a floordiv b =
106   ///            let negative = a < 0 in
107   ///            let absolute = negative ? -a - 1 : a in
108   ///            let quotient = absolute / b in
109   ///                negative ? -quotient - 1 : quotient
visitFloorDivExpr(AffineBinaryOpExpr expr)110   Value visitFloorDivExpr(AffineBinaryOpExpr expr) {
111     auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
112     if (!rhsConst) {
113       emitError(
114           loc,
115           "semi-affine expressions (division by non-const) are not supported");
116       return nullptr;
117     }
118     if (rhsConst.getValue() <= 0) {
119       emitError(loc, "division by non-positive value is not supported");
120       return nullptr;
121     }
122 
123     auto lhs = visit(expr.getLHS());
124     auto rhs = visit(expr.getRHS());
125     assert(lhs && rhs && "unexpected affine expr lowering failure");
126 
127     Value zeroCst = builder.create<ConstantIndexOp>(loc, 0);
128     Value noneCst = builder.create<ConstantIndexOp>(loc, -1);
129     Value negative =
130         builder.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, zeroCst);
131     Value negatedDecremented = builder.create<SubIOp>(loc, noneCst, lhs);
132     Value dividend =
133         builder.create<SelectOp>(loc, negative, negatedDecremented, lhs);
134     Value quotient = builder.create<SignedDivIOp>(loc, dividend, rhs);
135     Value correctedQuotient = builder.create<SubIOp>(loc, noneCst, quotient);
136     Value result =
137         builder.create<SelectOp>(loc, negative, correctedQuotient, quotient);
138     return result;
139   }
140 
141   /// Ceiling division operation (rounds towards positive infinity).
142   ///
143   /// For positive divisors, it can be implemented without branching and with a
144   /// single division operation as
145   ///
146   ///     a ceildiv b =
147   ///         let negative = a <= 0 in
148   ///         let absolute = negative ? -a : a - 1 in
149   ///         let quotient = absolute / b in
150   ///             negative ? -quotient : quotient + 1
visitCeilDivExpr(AffineBinaryOpExpr expr)151   Value visitCeilDivExpr(AffineBinaryOpExpr expr) {
152     auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
153     if (!rhsConst) {
154       emitError(loc) << "semi-affine expressions (division by non-const) are "
155                         "not supported";
156       return nullptr;
157     }
158     if (rhsConst.getValue() <= 0) {
159       emitError(loc, "division by non-positive value is not supported");
160       return nullptr;
161     }
162     auto lhs = visit(expr.getLHS());
163     auto rhs = visit(expr.getRHS());
164     assert(lhs && rhs && "unexpected affine expr lowering failure");
165 
166     Value zeroCst = builder.create<ConstantIndexOp>(loc, 0);
167     Value oneCst = builder.create<ConstantIndexOp>(loc, 1);
168     Value nonPositive =
169         builder.create<CmpIOp>(loc, CmpIPredicate::sle, lhs, zeroCst);
170     Value negated = builder.create<SubIOp>(loc, zeroCst, lhs);
171     Value decremented = builder.create<SubIOp>(loc, lhs, oneCst);
172     Value dividend =
173         builder.create<SelectOp>(loc, nonPositive, negated, decremented);
174     Value quotient = builder.create<SignedDivIOp>(loc, dividend, rhs);
175     Value negatedQuotient = builder.create<SubIOp>(loc, zeroCst, quotient);
176     Value incrementedQuotient = builder.create<AddIOp>(loc, quotient, oneCst);
177     Value result = builder.create<SelectOp>(loc, nonPositive, negatedQuotient,
178                                             incrementedQuotient);
179     return result;
180   }
181 
visitConstantExpr(AffineConstantExpr expr)182   Value visitConstantExpr(AffineConstantExpr expr) {
183     auto valueAttr =
184         builder.getIntegerAttr(builder.getIndexType(), expr.getValue());
185     auto op =
186         builder.create<ConstantOp>(loc, builder.getIndexType(), valueAttr);
187     return op.getResult();
188   }
189 
visitDimExpr(AffineDimExpr expr)190   Value visitDimExpr(AffineDimExpr expr) {
191     assert(expr.getPosition() < dimValues.size() &&
192            "affine dim position out of range");
193     return dimValues[expr.getPosition()];
194   }
195 
visitSymbolExpr(AffineSymbolExpr expr)196   Value visitSymbolExpr(AffineSymbolExpr expr) {
197     assert(expr.getPosition() < symbolValues.size() &&
198            "symbol dim position out of range");
199     return symbolValues[expr.getPosition()];
200   }
201 
202 private:
203   OpBuilder &builder;
204   ValueRange dimValues;
205   ValueRange symbolValues;
206 
207   Location loc;
208 };
209 } // namespace
210 
211 /// Create a sequence of operations that implement the `expr` applied to the
212 /// given dimension and symbol values.
expandAffineExpr(OpBuilder & builder,Location loc,AffineExpr expr,ValueRange dimValues,ValueRange symbolValues)213 mlir::Value mlir::expandAffineExpr(OpBuilder &builder, Location loc,
214                                    AffineExpr expr, ValueRange dimValues,
215                                    ValueRange symbolValues) {
216   return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr);
217 }
218 
219 /// Create a sequence of operations that implement the `affineMap` applied to
220 /// the given `operands` (as it it were an AffineApplyOp).
expandAffineMap(OpBuilder & builder,Location loc,AffineMap affineMap,ValueRange operands)221 Optional<SmallVector<Value, 8>> mlir::expandAffineMap(OpBuilder &builder,
222                                                       Location loc,
223                                                       AffineMap affineMap,
224                                                       ValueRange operands) {
225   auto numDims = affineMap.getNumDims();
226   auto expanded = llvm::to_vector<8>(
227       llvm::map_range(affineMap.getResults(),
228                       [numDims, &builder, loc, operands](AffineExpr expr) {
229                         return expandAffineExpr(builder, loc, expr,
230                                                 operands.take_front(numDims),
231                                                 operands.drop_front(numDims));
232                       }));
233   if (llvm::all_of(expanded, [](Value v) { return v; }))
234     return expanded;
235   return None;
236 }
237 
238 /// Given a range of values, emit the code that reduces them with "min" or "max"
239 /// depending on the provided comparison predicate.  The predicate defines which
240 /// comparison to perform, "lt" for "min", "gt" for "max" and is used for the
241 /// `cmpi` operation followed by the `select` operation:
242 ///
243 ///   %cond   = cmpi "predicate" %v0, %v1
244 ///   %result = select %cond, %v0, %v1
245 ///
246 /// Multiple values are scanned in a linear sequence.  This creates a data
247 /// dependences that wouldn't exist in a tree reduction, but is easier to
248 /// recognize as a reduction by the subsequent passes.
buildMinMaxReductionSeq(Location loc,CmpIPredicate predicate,ValueRange values,OpBuilder & builder)249 static Value buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate,
250                                      ValueRange values, OpBuilder &builder) {
251   assert(!llvm::empty(values) && "empty min/max chain");
252 
253   auto valueIt = values.begin();
254   Value value = *valueIt++;
255   for (; valueIt != values.end(); ++valueIt) {
256     auto cmpOp = builder.create<CmpIOp>(loc, predicate, value, *valueIt);
257     value = builder.create<SelectOp>(loc, cmpOp.getResult(), value, *valueIt);
258   }
259 
260   return value;
261 }
262 
263 /// Emit instructions that correspond to computing the maximum value among the
264 /// values of a (potentially) multi-output affine map applied to `operands`.
lowerAffineMapMax(OpBuilder & builder,Location loc,AffineMap map,ValueRange operands)265 static Value lowerAffineMapMax(OpBuilder &builder, Location loc, AffineMap map,
266                                ValueRange operands) {
267   if (auto values = expandAffineMap(builder, loc, map, operands))
268     return buildMinMaxReductionSeq(loc, CmpIPredicate::sgt, *values, builder);
269   return nullptr;
270 }
271 
272 /// Emit instructions that correspond to computing the minimum value among the
273 /// values of a (potentially) multi-output affine map applied to `operands`.
lowerAffineMapMin(OpBuilder & builder,Location loc,AffineMap map,ValueRange operands)274 static Value lowerAffineMapMin(OpBuilder &builder, Location loc, AffineMap map,
275                                ValueRange operands) {
276   if (auto values = expandAffineMap(builder, loc, map, operands))
277     return buildMinMaxReductionSeq(loc, CmpIPredicate::slt, *values, builder);
278   return nullptr;
279 }
280 
281 /// Emit instructions that correspond to the affine map in the upper bound
282 /// applied to the respective operands, and compute the minimum value across
283 /// the results.
lowerAffineUpperBound(AffineForOp op,OpBuilder & builder)284 Value mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) {
285   return lowerAffineMapMin(builder, op.getLoc(), op.getUpperBoundMap(),
286                            op.getUpperBoundOperands());
287 }
288 
289 /// Emit instructions that correspond to the affine map in the lower bound
290 /// applied to the respective operands, and compute the maximum value across
291 /// the results.
lowerAffineLowerBound(AffineForOp op,OpBuilder & builder)292 Value mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) {
293   return lowerAffineMapMax(builder, op.getLoc(), op.getLowerBoundMap(),
294                            op.getLowerBoundOperands());
295 }
296 
297 namespace {
298 class AffineMinLowering : public OpRewritePattern<AffineMinOp> {
299 public:
300   using OpRewritePattern<AffineMinOp>::OpRewritePattern;
301 
matchAndRewrite(AffineMinOp op,PatternRewriter & rewriter) const302   LogicalResult matchAndRewrite(AffineMinOp op,
303                                 PatternRewriter &rewriter) const override {
304     Value reduced =
305         lowerAffineMapMin(rewriter, op.getLoc(), op.map(), op.operands());
306     if (!reduced)
307       return failure();
308 
309     rewriter.replaceOp(op, reduced);
310     return success();
311   }
312 };
313 
314 class AffineMaxLowering : public OpRewritePattern<AffineMaxOp> {
315 public:
316   using OpRewritePattern<AffineMaxOp>::OpRewritePattern;
317 
matchAndRewrite(AffineMaxOp op,PatternRewriter & rewriter) const318   LogicalResult matchAndRewrite(AffineMaxOp op,
319                                 PatternRewriter &rewriter) const override {
320     Value reduced =
321         lowerAffineMapMax(rewriter, op.getLoc(), op.map(), op.operands());
322     if (!reduced)
323       return failure();
324 
325     rewriter.replaceOp(op, reduced);
326     return success();
327   }
328 };
329 
330 /// Affine yields ops are removed.
331 class AffineYieldOpLowering : public OpRewritePattern<AffineYieldOp> {
332 public:
333   using OpRewritePattern<AffineYieldOp>::OpRewritePattern;
334 
matchAndRewrite(AffineYieldOp op,PatternRewriter & rewriter) const335   LogicalResult matchAndRewrite(AffineYieldOp op,
336                                 PatternRewriter &rewriter) const override {
337     rewriter.replaceOpWithNewOp<scf::YieldOp>(op);
338     return success();
339   }
340 };
341 
342 class AffineForLowering : public OpRewritePattern<AffineForOp> {
343 public:
344   using OpRewritePattern<AffineForOp>::OpRewritePattern;
345 
matchAndRewrite(AffineForOp op,PatternRewriter & rewriter) const346   LogicalResult matchAndRewrite(AffineForOp op,
347                                 PatternRewriter &rewriter) const override {
348     Location loc = op.getLoc();
349     Value lowerBound = lowerAffineLowerBound(op, rewriter);
350     Value upperBound = lowerAffineUpperBound(op, rewriter);
351     Value step = rewriter.create<ConstantIndexOp>(loc, op.getStep());
352     auto f = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
353     rewriter.eraseBlock(f.getBody());
354     rewriter.inlineRegionBefore(op.region(), f.region(), f.region().end());
355     rewriter.eraseOp(op);
356     return success();
357   }
358 };
359 
360 /// Convert an `affine.parallel` (loop nest) operation into a `scf.parallel`
361 /// operation.
362 class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
363 public:
364   using OpRewritePattern<AffineParallelOp>::OpRewritePattern;
365 
matchAndRewrite(AffineParallelOp op,PatternRewriter & rewriter) const366   LogicalResult matchAndRewrite(AffineParallelOp op,
367                                 PatternRewriter &rewriter) const override {
368     Location loc = op.getLoc();
369     SmallVector<Value, 8> steps;
370     SmallVector<Value, 8> upperBoundTuple;
371     SmallVector<Value, 8> lowerBoundTuple;
372     // Finding lower and upper bound by expanding the map expression.
373     // Checking if expandAffineMap is not giving NULL.
374     Optional<SmallVector<Value, 8>> upperBound = expandAffineMap(
375         rewriter, loc, op.upperBoundsMap(), op.getUpperBoundsOperands());
376     Optional<SmallVector<Value, 8>> lowerBound = expandAffineMap(
377         rewriter, loc, op.lowerBoundsMap(), op.getLowerBoundsOperands());
378     if (!lowerBound || !upperBound)
379       return failure();
380     upperBoundTuple = *upperBound;
381     lowerBoundTuple = *lowerBound;
382     steps.reserve(op.steps().size());
383     for (Attribute step : op.steps())
384       steps.push_back(rewriter.create<ConstantIndexOp>(
385           loc, step.cast<IntegerAttr>().getInt()));
386     // Creating empty scf.parallel op body with appropriate bounds.
387     auto parallelOp = rewriter.create<scf::ParallelOp>(loc, lowerBoundTuple,
388                                                        upperBoundTuple, steps);
389     rewriter.eraseBlock(parallelOp.getBody());
390     rewriter.inlineRegionBefore(op.region(), parallelOp.region(),
391                                 parallelOp.region().end());
392     rewriter.eraseOp(op);
393     return success();
394   }
395 };
396 
397 class AffineIfLowering : public OpRewritePattern<AffineIfOp> {
398 public:
399   using OpRewritePattern<AffineIfOp>::OpRewritePattern;
400 
matchAndRewrite(AffineIfOp op,PatternRewriter & rewriter) const401   LogicalResult matchAndRewrite(AffineIfOp op,
402                                 PatternRewriter &rewriter) const override {
403     auto loc = op.getLoc();
404 
405     // Now we just have to handle the condition logic.
406     auto integerSet = op.getIntegerSet();
407     Value zeroConstant = rewriter.create<ConstantIndexOp>(loc, 0);
408     SmallVector<Value, 8> operands(op.getOperands());
409     auto operandsRef = llvm::makeArrayRef(operands);
410 
411     // Calculate cond as a conjunction without short-circuiting.
412     Value cond = nullptr;
413     for (unsigned i = 0, e = integerSet.getNumConstraints(); i < e; ++i) {
414       AffineExpr constraintExpr = integerSet.getConstraint(i);
415       bool isEquality = integerSet.isEq(i);
416 
417       // Build and apply an affine expression
418       auto numDims = integerSet.getNumDims();
419       Value affResult = expandAffineExpr(rewriter, loc, constraintExpr,
420                                          operandsRef.take_front(numDims),
421                                          operandsRef.drop_front(numDims));
422       if (!affResult)
423         return failure();
424       auto pred = isEquality ? CmpIPredicate::eq : CmpIPredicate::sge;
425       Value cmpVal =
426           rewriter.create<CmpIOp>(loc, pred, affResult, zeroConstant);
427       cond =
428           cond ? rewriter.create<AndOp>(loc, cond, cmpVal).getResult() : cmpVal;
429     }
430     cond = cond ? cond
431                 : rewriter.create<ConstantIntOp>(loc, /*value=*/1, /*width=*/1);
432 
433     bool hasElseRegion = !op.elseRegion().empty();
434     auto ifOp = rewriter.create<scf::IfOp>(loc, cond, hasElseRegion);
435     rewriter.inlineRegionBefore(op.thenRegion(), &ifOp.thenRegion().back());
436     rewriter.eraseBlock(&ifOp.thenRegion().back());
437     if (hasElseRegion) {
438       rewriter.inlineRegionBefore(op.elseRegion(), &ifOp.elseRegion().back());
439       rewriter.eraseBlock(&ifOp.elseRegion().back());
440     }
441 
442     // Ok, we're done!
443     rewriter.eraseOp(op);
444     return success();
445   }
446 };
447 
448 /// Convert an "affine.apply" operation into a sequence of arithmetic
449 /// operations using the StandardOps dialect.
450 class AffineApplyLowering : public OpRewritePattern<AffineApplyOp> {
451 public:
452   using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
453 
matchAndRewrite(AffineApplyOp op,PatternRewriter & rewriter) const454   LogicalResult matchAndRewrite(AffineApplyOp op,
455                                 PatternRewriter &rewriter) const override {
456     auto maybeExpandedMap =
457         expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
458                         llvm::to_vector<8>(op.getOperands()));
459     if (!maybeExpandedMap)
460       return failure();
461     rewriter.replaceOp(op, *maybeExpandedMap);
462     return success();
463   }
464 };
465 
466 /// Apply the affine map from an 'affine.load' operation to its operands, and
467 /// feed the results to a newly created 'std.load' operation (which replaces the
468 /// original 'affine.load').
469 class AffineLoadLowering : public OpRewritePattern<AffineLoadOp> {
470 public:
471   using OpRewritePattern<AffineLoadOp>::OpRewritePattern;
472 
matchAndRewrite(AffineLoadOp op,PatternRewriter & rewriter) const473   LogicalResult matchAndRewrite(AffineLoadOp op,
474                                 PatternRewriter &rewriter) const override {
475     // Expand affine map from 'affineLoadOp'.
476     SmallVector<Value, 8> indices(op.getMapOperands());
477     auto resultOperands =
478         expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
479     if (!resultOperands)
480       return failure();
481 
482     // Build std.load memref[expandedMap.results].
483     rewriter.replaceOpWithNewOp<LoadOp>(op, op.getMemRef(), *resultOperands);
484     return success();
485   }
486 };
487 
488 /// Apply the affine map from an 'affine.prefetch' operation to its operands,
489 /// and feed the results to a newly created 'std.prefetch' operation (which
490 /// replaces the original 'affine.prefetch').
491 class AffinePrefetchLowering : public OpRewritePattern<AffinePrefetchOp> {
492 public:
493   using OpRewritePattern<AffinePrefetchOp>::OpRewritePattern;
494 
matchAndRewrite(AffinePrefetchOp op,PatternRewriter & rewriter) const495   LogicalResult matchAndRewrite(AffinePrefetchOp op,
496                                 PatternRewriter &rewriter) const override {
497     // Expand affine map from 'affinePrefetchOp'.
498     SmallVector<Value, 8> indices(op.getMapOperands());
499     auto resultOperands =
500         expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
501     if (!resultOperands)
502       return failure();
503 
504     // Build std.prefetch memref[expandedMap.results].
505     rewriter.replaceOpWithNewOp<PrefetchOp>(op, op.memref(), *resultOperands,
506                                             op.isWrite(), op.localityHint(),
507                                             op.isDataCache());
508     return success();
509   }
510 };
511 
512 /// Apply the affine map from an 'affine.store' operation to its operands, and
513 /// feed the results to a newly created 'std.store' operation (which replaces
514 /// the original 'affine.store').
515 class AffineStoreLowering : public OpRewritePattern<AffineStoreOp> {
516 public:
517   using OpRewritePattern<AffineStoreOp>::OpRewritePattern;
518 
matchAndRewrite(AffineStoreOp op,PatternRewriter & rewriter) const519   LogicalResult matchAndRewrite(AffineStoreOp op,
520                                 PatternRewriter &rewriter) const override {
521     // Expand affine map from 'affineStoreOp'.
522     SmallVector<Value, 8> indices(op.getMapOperands());
523     auto maybeExpandedMap =
524         expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
525     if (!maybeExpandedMap)
526       return failure();
527 
528     // Build std.store valueToStore, memref[expandedMap.results].
529     rewriter.replaceOpWithNewOp<StoreOp>(op, op.getValueToStore(),
530                                          op.getMemRef(), *maybeExpandedMap);
531     return success();
532   }
533 };
534 
535 /// Apply the affine maps from an 'affine.dma_start' operation to each of their
536 /// respective map operands, and feed the results to a newly created
537 /// 'std.dma_start' operation (which replaces the original 'affine.dma_start').
538 class AffineDmaStartLowering : public OpRewritePattern<AffineDmaStartOp> {
539 public:
540   using OpRewritePattern<AffineDmaStartOp>::OpRewritePattern;
541 
matchAndRewrite(AffineDmaStartOp op,PatternRewriter & rewriter) const542   LogicalResult matchAndRewrite(AffineDmaStartOp op,
543                                 PatternRewriter &rewriter) const override {
544     SmallVector<Value, 8> operands(op.getOperands());
545     auto operandsRef = llvm::makeArrayRef(operands);
546 
547     // Expand affine map for DMA source memref.
548     auto maybeExpandedSrcMap = expandAffineMap(
549         rewriter, op.getLoc(), op.getSrcMap(),
550         operandsRef.drop_front(op.getSrcMemRefOperandIndex() + 1));
551     if (!maybeExpandedSrcMap)
552       return failure();
553     // Expand affine map for DMA destination memref.
554     auto maybeExpandedDstMap = expandAffineMap(
555         rewriter, op.getLoc(), op.getDstMap(),
556         operandsRef.drop_front(op.getDstMemRefOperandIndex() + 1));
557     if (!maybeExpandedDstMap)
558       return failure();
559     // Expand affine map for DMA tag memref.
560     auto maybeExpandedTagMap = expandAffineMap(
561         rewriter, op.getLoc(), op.getTagMap(),
562         operandsRef.drop_front(op.getTagMemRefOperandIndex() + 1));
563     if (!maybeExpandedTagMap)
564       return failure();
565 
566     // Build std.dma_start operation with affine map results.
567     rewriter.replaceOpWithNewOp<DmaStartOp>(
568         op, op.getSrcMemRef(), *maybeExpandedSrcMap, op.getDstMemRef(),
569         *maybeExpandedDstMap, op.getNumElements(), op.getTagMemRef(),
570         *maybeExpandedTagMap, op.getStride(), op.getNumElementsPerStride());
571     return success();
572   }
573 };
574 
575 /// Apply the affine map from an 'affine.dma_wait' operation tag memref,
576 /// and feed the results to a newly created 'std.dma_wait' operation (which
577 /// replaces the original 'affine.dma_wait').
578 class AffineDmaWaitLowering : public OpRewritePattern<AffineDmaWaitOp> {
579 public:
580   using OpRewritePattern<AffineDmaWaitOp>::OpRewritePattern;
581 
matchAndRewrite(AffineDmaWaitOp op,PatternRewriter & rewriter) const582   LogicalResult matchAndRewrite(AffineDmaWaitOp op,
583                                 PatternRewriter &rewriter) const override {
584     // Expand affine map for DMA tag memref.
585     SmallVector<Value, 8> indices(op.getTagIndices());
586     auto maybeExpandedTagMap =
587         expandAffineMap(rewriter, op.getLoc(), op.getTagMap(), indices);
588     if (!maybeExpandedTagMap)
589       return failure();
590 
591     // Build std.dma_wait operation with affine map results.
592     rewriter.replaceOpWithNewOp<DmaWaitOp>(
593         op, op.getTagMemRef(), *maybeExpandedTagMap, op.getNumElements());
594     return success();
595   }
596 };
597 
598 /// Apply the affine map from an 'affine.vector_load' operation to its operands,
599 /// and feed the results to a newly created 'vector.transfer_read' operation
600 /// (which replaces the original 'affine.vector_load').
601 class AffineVectorLoadLowering : public OpRewritePattern<AffineVectorLoadOp> {
602 public:
603   using OpRewritePattern<AffineVectorLoadOp>::OpRewritePattern;
604 
matchAndRewrite(AffineVectorLoadOp op,PatternRewriter & rewriter) const605   LogicalResult matchAndRewrite(AffineVectorLoadOp op,
606                                 PatternRewriter &rewriter) const override {
607     // Expand affine map from 'affineVectorLoadOp'.
608     SmallVector<Value, 8> indices(op.getMapOperands());
609     auto resultOperands =
610         expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
611     if (!resultOperands)
612       return failure();
613 
614     // Build vector.transfer_read memref[expandedMap.results].
615     rewriter.replaceOpWithNewOp<TransferReadOp>(
616         op, op.getVectorType(), op.getMemRef(), *resultOperands);
617     return success();
618   }
619 };
620 
621 /// Apply the affine map from an 'affine.vector_store' operation to its
622 /// operands, and feed the results to a newly created 'vector.transfer_write'
623 /// operation (which replaces the original 'affine.vector_store').
624 class AffineVectorStoreLowering : public OpRewritePattern<AffineVectorStoreOp> {
625 public:
626   using OpRewritePattern<AffineVectorStoreOp>::OpRewritePattern;
627 
matchAndRewrite(AffineVectorStoreOp op,PatternRewriter & rewriter) const628   LogicalResult matchAndRewrite(AffineVectorStoreOp op,
629                                 PatternRewriter &rewriter) const override {
630     // Expand affine map from 'affineVectorStoreOp'.
631     SmallVector<Value, 8> indices(op.getMapOperands());
632     auto maybeExpandedMap =
633         expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
634     if (!maybeExpandedMap)
635       return failure();
636 
637     rewriter.replaceOpWithNewOp<TransferWriteOp>(
638         op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
639     return success();
640   }
641 };
642 
643 } // end namespace
644 
populateAffineToStdConversionPatterns(OwningRewritePatternList & patterns,MLIRContext * ctx)645 void mlir::populateAffineToStdConversionPatterns(
646     OwningRewritePatternList &patterns, MLIRContext *ctx) {
647   // clang-format off
648   patterns.insert<
649       AffineApplyLowering,
650       AffineDmaStartLowering,
651       AffineDmaWaitLowering,
652       AffineLoadLowering,
653       AffineMinLowering,
654       AffineMaxLowering,
655       AffineParallelLowering,
656       AffinePrefetchLowering,
657       AffineStoreLowering,
658       AffineForLowering,
659       AffineIfLowering,
660       AffineYieldOpLowering>(ctx);
661   // clang-format on
662 }
663 
populateAffineToVectorConversionPatterns(OwningRewritePatternList & patterns,MLIRContext * ctx)664 void mlir::populateAffineToVectorConversionPatterns(
665     OwningRewritePatternList &patterns, MLIRContext *ctx) {
666   // clang-format off
667   patterns.insert<
668       AffineVectorLoadLowering,
669       AffineVectorStoreLowering>(ctx);
670   // clang-format on
671 }
672 
673 namespace {
674 class LowerAffinePass : public ConvertAffineToStandardBase<LowerAffinePass> {
runOnOperation()675   void runOnOperation() override {
676     OwningRewritePatternList patterns;
677     populateAffineToStdConversionPatterns(patterns, &getContext());
678     populateAffineToVectorConversionPatterns(patterns, &getContext());
679     ConversionTarget target(getContext());
680     target
681         .addLegalDialect<scf::SCFDialect, StandardOpsDialect, VectorDialect>();
682     if (failed(applyPartialConversion(getOperation(), target,
683                                       std::move(patterns))))
684       signalPassFailure();
685   }
686 };
687 } // namespace
688 
689 /// Lowers If and For operations within a function into their lower level CFG
690 /// equivalent blocks.
createLowerAffinePass()691 std::unique_ptr<Pass> mlir::createLowerAffinePass() {
692   return std::make_unique<LowerAffinePass>();
693 }
694