1 //===- AsyncParallelFor.cpp - Implementation of Async Parallel For --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements scf.parallel to scf.for + async.execute conversion pass.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Async/IR/Async.h"
15 #include "mlir/Dialect/Async/Passes.h"
16 #include "mlir/Dialect/SCF/SCF.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/IR/BlockAndValueMapping.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21
22 using namespace mlir;
23 using namespace mlir::async;
24
25 #define DEBUG_TYPE "async-parallel-for"
26
27 namespace {
28
29 // Rewrite scf.parallel operation into multiple concurrent async.execute
30 // operations over non overlapping subranges of the original loop.
31 //
32 // Example:
33 //
34 // scf.for (%i, %j) = (%lbi, %lbj) to (%ubi, %ubj) step (%si, %sj) {
35 // "do_some_compute"(%i, %j): () -> ()
36 // }
37 //
38 // Converted to:
39 //
40 // %c0 = constant 0 : index
41 // %c1 = constant 1 : index
42 //
43 // // Compute blocks sizes for each induction variable.
44 // %num_blocks_i = ... : index
45 // %num_blocks_j = ... : index
46 // %block_size_i = ... : index
47 // %block_size_j = ... : index
48 //
49 // // Create an async group to track async execute ops.
50 // %group = async.create_group
51 //
52 // scf.for %bi = %c0 to %num_blocks_i step %c1 {
53 // %block_start_i = ... : index
54 // %block_end_i = ... : index
55 //
56 // scf.for %bj = %c0 to %num_blocks_j step %c1 {
57 // %block_start_j = ... : index
58 // %block_end_j = ... : index
59 //
60 // // Execute the body of original parallel operation for the current
61 // // block.
62 // %token = async.execute {
63 // scf.for %i = %block_start_i to %block_end_i step %si {
64 // scf.for %j = %block_start_j to %block_end_j step %sj {
65 // "do_some_compute"(%i, %j): () -> ()
66 // }
67 // }
68 // }
69 //
70 // // Add produced async token to the group.
71 // async.add_to_group %token, %group
72 // }
73 // }
74 //
75 // // Await completion of all async.execute operations.
76 // async.await_all %group
77 //
78 // In this example outer loop launches inner block level loops as separate async
79 // execute operations which will be executed concurrently.
80 //
81 // At the end it waits for the completion of all async execute operations.
82 //
83 struct AsyncParallelForRewrite : public OpRewritePattern<scf::ParallelOp> {
84 public:
AsyncParallelForRewrite__anon03e258e20111::AsyncParallelForRewrite85 AsyncParallelForRewrite(MLIRContext *ctx, int numConcurrentAsyncExecute)
86 : OpRewritePattern(ctx),
87 numConcurrentAsyncExecute(numConcurrentAsyncExecute) {}
88
89 LogicalResult matchAndRewrite(scf::ParallelOp op,
90 PatternRewriter &rewriter) const override;
91
92 private:
93 int numConcurrentAsyncExecute;
94 };
95
96 struct AsyncParallelForPass
97 : public AsyncParallelForBase<AsyncParallelForPass> {
98 AsyncParallelForPass() = default;
AsyncParallelForPass__anon03e258e20111::AsyncParallelForPass99 AsyncParallelForPass(int numWorkerThreads) {
100 assert(numWorkerThreads >= 1);
101 numConcurrentAsyncExecute = numWorkerThreads;
102 }
103 void runOnFunction() override;
104 };
105
106 } // namespace
107
108 LogicalResult
matchAndRewrite(scf::ParallelOp op,PatternRewriter & rewriter) const109 AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
110 PatternRewriter &rewriter) const {
111 // We do not currently support rewrite for parallel op with reductions.
112 if (op.getNumReductions() != 0)
113 return failure();
114
115 MLIRContext *ctx = op.getContext();
116 Location loc = op.getLoc();
117
118 // Index constants used below.
119 auto indexTy = IndexType::get(ctx);
120 auto zero = IntegerAttr::get(indexTy, 0);
121 auto one = IntegerAttr::get(indexTy, 1);
122 auto c0 = rewriter.create<ConstantOp>(loc, indexTy, zero);
123 auto c1 = rewriter.create<ConstantOp>(loc, indexTy, one);
124
125 // Shorthand for signed integer ceil division operation.
126 auto divup = [&](Value x, Value y) -> Value {
127 return rewriter.create<SignedCeilDivIOp>(loc, x, y);
128 };
129
130 // Compute trip count for each loop induction variable:
131 // tripCount = divUp(upperBound - lowerBound, step);
132 SmallVector<Value, 4> tripCounts(op.getNumLoops());
133 for (size_t i = 0; i < op.getNumLoops(); ++i) {
134 auto lb = op.lowerBound()[i];
135 auto ub = op.upperBound()[i];
136 auto step = op.step()[i];
137 auto range = rewriter.create<SubIOp>(loc, ub, lb);
138 tripCounts[i] = divup(range, step);
139 }
140
141 // The target number of concurrent async.execute ops.
142 auto numExecuteOps = rewriter.create<ConstantOp>(
143 loc, indexTy, IntegerAttr::get(indexTy, numConcurrentAsyncExecute));
144
145 // Blocks sizes configuration for each induction variable.
146
147 // We try to use maximum available concurrency in outer dimensions first
148 // (assuming that parallel induction variables are corresponding to some
149 // multidimensional access, e.g. in (%d0, %d1, ..., %dn) = (<from>) to (<to>)
150 // we will try to parallelize iteration along the %d0. If %d0 is too small,
151 // we'll parallelize iteration over %d1, and so on.
152 SmallVector<Value, 4> targetNumBlocks(op.getNumLoops());
153 SmallVector<Value, 4> blockSize(op.getNumLoops());
154 SmallVector<Value, 4> numBlocks(op.getNumLoops());
155
156 // Compute block size and number of blocks along the first induction variable.
157 targetNumBlocks[0] = numExecuteOps;
158 blockSize[0] = divup(tripCounts[0], targetNumBlocks[0]);
159 numBlocks[0] = divup(tripCounts[0], blockSize[0]);
160
161 // Assign remaining available concurrency to other induction variables.
162 for (size_t i = 1; i < op.getNumLoops(); ++i) {
163 targetNumBlocks[i] = divup(targetNumBlocks[i - 1], numBlocks[i - 1]);
164 blockSize[i] = divup(tripCounts[i], targetNumBlocks[i]);
165 numBlocks[i] = divup(tripCounts[i], blockSize[i]);
166 }
167
168 // Create an async.group to wait on all async tokens from async execute ops.
169 auto group = rewriter.create<CreateGroupOp>(loc, GroupType::get(ctx));
170
171 // Build a scf.for loop nest from the parallel operation.
172
173 // Lower/upper bounds for nest block level computations.
174 SmallVector<Value, 4> blockLowerBounds(op.getNumLoops());
175 SmallVector<Value, 4> blockUpperBounds(op.getNumLoops());
176 SmallVector<Value, 4> blockInductionVars(op.getNumLoops());
177
178 using LoopBodyBuilder =
179 std::function<void(OpBuilder &, Location, Value, ValueRange)>;
180 using LoopBuilder = std::function<LoopBodyBuilder(size_t loopIdx)>;
181
182 // Builds inner loop nest inside async.execute operation that does all the
183 // work concurrently.
184 LoopBuilder workLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
185 return [&, loopIdx](OpBuilder &b, Location loc, Value iv, ValueRange args) {
186 blockInductionVars[loopIdx] = iv;
187
188 // Continute building async loop nest.
189 if (loopIdx < op.getNumLoops() - 1) {
190 b.create<scf::ForOp>(
191 loc, blockLowerBounds[loopIdx + 1], blockUpperBounds[loopIdx + 1],
192 op.step()[loopIdx + 1], ValueRange(), workLoopBuilder(loopIdx + 1));
193 b.create<scf::YieldOp>(loc);
194 return;
195 }
196
197 // Copy the body of the parallel op with new loop bounds.
198 BlockAndValueMapping mapping;
199 mapping.map(op.getInductionVars(), blockInductionVars);
200
201 for (auto &bodyOp : op.getLoopBody().getOps())
202 b.clone(bodyOp, mapping);
203 };
204 };
205
206 // Builds a loop nest that does async execute op dispatching.
207 LoopBuilder asyncLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
208 return [&, loopIdx](OpBuilder &b, Location loc, Value iv, ValueRange args) {
209 auto lb = op.lowerBound()[loopIdx];
210 auto ub = op.upperBound()[loopIdx];
211 auto step = op.step()[loopIdx];
212
213 // Compute lower bound for the current block:
214 // blockLowerBound = iv * blockSize * step + lowerBound
215 auto s0 = b.create<MulIOp>(loc, iv, blockSize[loopIdx]);
216 auto s1 = b.create<MulIOp>(loc, s0, step);
217 auto s2 = b.create<AddIOp>(loc, s1, lb);
218 blockLowerBounds[loopIdx] = s2;
219
220 // Compute upper bound for the current block:
221 // blockUpperBound = min(upperBound,
222 // blockLowerBound + blockSize * step)
223 auto e0 = b.create<MulIOp>(loc, blockSize[loopIdx], step);
224 auto e1 = b.create<AddIOp>(loc, e0, s2);
225 auto e2 = b.create<CmpIOp>(loc, CmpIPredicate::slt, e1, ub);
226 auto e3 = b.create<SelectOp>(loc, e2, e1, ub);
227 blockUpperBounds[loopIdx] = e3;
228
229 // Continue building async dispatch loop nest.
230 if (loopIdx < op.getNumLoops() - 1) {
231 b.create<scf::ForOp>(loc, c0, numBlocks[loopIdx + 1], c1, ValueRange(),
232 asyncLoopBuilder(loopIdx + 1));
233 b.create<scf::YieldOp>(loc);
234 return;
235 }
236
237 // Build the inner loop nest that will do the actual work inside the
238 // `async.execute` body region.
239 auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
240 Location executeLoc,
241 ValueRange executeArgs) {
242 executeBuilder.create<scf::ForOp>(executeLoc, blockLowerBounds[0],
243 blockUpperBounds[0], op.step()[0],
244 ValueRange(), workLoopBuilder(0));
245 executeBuilder.create<async::YieldOp>(executeLoc, ValueRange());
246 };
247
248 auto execute = b.create<ExecuteOp>(
249 loc, /*resultTypes=*/TypeRange(), /*dependencies=*/ValueRange(),
250 /*operands=*/ValueRange(), executeBodyBuilder);
251 auto rankType = IndexType::get(ctx);
252 b.create<AddToGroupOp>(loc, rankType, execute.token(), group.result());
253 b.create<scf::YieldOp>(loc);
254 };
255 };
256
257 // Start building a loop nest from the first induction variable.
258 rewriter.create<scf::ForOp>(loc, c0, numBlocks[0], c1, ValueRange(),
259 asyncLoopBuilder(0));
260
261 // Wait for the completion of all subtasks.
262 rewriter.create<AwaitAllOp>(loc, group.result());
263
264 // Erase the original parallel operation.
265 rewriter.eraseOp(op);
266
267 return success();
268 }
269
runOnFunction()270 void AsyncParallelForPass::runOnFunction() {
271 MLIRContext *ctx = &getContext();
272
273 OwningRewritePatternList patterns;
274 patterns.insert<AsyncParallelForRewrite>(ctx, numConcurrentAsyncExecute);
275
276 if (failed(applyPatternsAndFoldGreedily(getFunction(), std::move(patterns))))
277 signalPassFailure();
278 }
279
createAsyncParallelForPass()280 std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncParallelForPass() {
281 return std::make_unique<AsyncParallelForPass>();
282 }
283
284 std::unique_ptr<OperationPass<FuncOp>>
createAsyncParallelForPass(int numWorkerThreads)285 mlir::createAsyncParallelForPass(int numWorkerThreads) {
286 return std::make_unique<AsyncParallelForPass>(numWorkerThreads);
287 }
288