1 //===- AsyncRegionRewriter.cpp - Implementation of GPU async rewriters ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the GPU dialect pattern rewriters that make GPU ops
// within a region execute asynchronously.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Async/IR/Async.h"
16 #include "mlir/Dialect/GPU/GPUDialect.h"
17 #include "mlir/Dialect/GPU/Passes.h"
18 #include "mlir/Dialect/GPU/Utils.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/IR/BlockAndValueMapping.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/SymbolTable.h"
24 #include "mlir/Support/LLVM.h"
25 #include "mlir/Transforms/RegionUtils.h"
26 #include "llvm/ADT/TypeSwitch.h"
27
28 using namespace mlir;
29 namespace {
30 class GpuAsyncRegionPass : public GpuAsyncRegionPassBase<GpuAsyncRegionPass> {
31 struct ThreadTokenCallback;
32 struct DeferWaitCallback;
33 void runOnFunction() override;
34 };
35 } // namespace
36
isTerminator(Operation * op)37 static bool isTerminator(Operation *op) { return !op->isKnownNonTerminator(); }
hasSideEffects(Operation * op)38 static bool hasSideEffects(Operation *op) {
39 return !MemoryEffectOpInterface::hasNoEffect(op);
40 }
41
42 // Region walk callback which makes GPU ops implementing the AsyncOpInterface
43 // execute asynchronously.
44 struct GpuAsyncRegionPass::ThreadTokenCallback {
ThreadTokenCallbackGpuAsyncRegionPass::ThreadTokenCallback45 ThreadTokenCallback(MLIRContext &context) : builder(&context) {}
46
47 // If `op` implements the AsyncOpInterface, insert a `gpu.wait async` to
48 // create a current token (unless it already exists), and 'thread' that token
49 // through the `op` so that it executes asynchronously.
50 //
51 // If `op` is a terminator or an op with side-effects, insert a `gpu.wait` to
52 // host-synchronize execution. A `!gpu.async.token` will therefore only be
53 // used inside of its block and GPU execution will always synchronize with
54 // the host at block boundaries.
operator ()GpuAsyncRegionPass::ThreadTokenCallback55 WalkResult operator()(Operation *op) {
56 if (isa<gpu::LaunchOp>(op))
57 return op->emitOpError("replace with gpu.launch_func first");
58 if (isa<gpu::WaitOp>(op))
59 return op->emitOpError("unexpected pre-existing gpu.wait");
60 builder.setInsertionPoint(op);
61 if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(op))
62 return rewriteAsyncOp(asyncOp); // Replace GPU op with async version.
63 if (!currentToken)
64 return success();
65 // Insert host synchronization before terminator or op with side effects.
66 if (isTerminator(op) || hasSideEffects(op))
67 currentToken = createWaitOp(op->getLoc(), Type(), {currentToken});
68 return success();
69 }
70
71 private:
72 // Replaces asyncOp with a clone that returns a token.
rewriteAsyncOpGpuAsyncRegionPass::ThreadTokenCallback73 LogicalResult rewriteAsyncOp(gpu::AsyncOpInterface asyncOp) {
74 auto *op = asyncOp.getOperation();
75 if (asyncOp.getAsyncToken())
76 // TODO: Support ops that are already async.
77 return op->emitOpError("is already async");
78 if (op->getNumRegions() > 0)
79 return op->emitOpError("regions are not supported");
80
81 // If there is no current token, insert a `gpu.wait async` without
82 // dependencies to create one.
83 if (!currentToken)
84 currentToken = createWaitOp(op->getLoc(), tokenType, {});
85 asyncOp.addAsyncDependency(currentToken);
86
87 // Clone the op to return a token in addition to the other results.
88 SmallVector<Type, 1> resultTypes = {tokenType};
89 resultTypes.reserve(1 + op->getNumResults());
90 copy(op->getResultTypes(), std::back_inserter(resultTypes));
91 auto *newOp = Operation::create(op->getLoc(), op->getName(), resultTypes,
92 op->getOperands(), op->getMutableAttrDict(),
93 op->getSuccessors());
94
95 // Replace the op with the async clone.
96 auto results = newOp->getResults();
97 currentToken = results.front();
98 builder.insert(newOp);
99 op->replaceAllUsesWith(results.drop_front());
100 op->erase();
101
102 return success();
103 }
104
createWaitOpGpuAsyncRegionPass::ThreadTokenCallback105 Value createWaitOp(Location loc, Type resultType, ValueRange operands) {
106 return builder.create<gpu::WaitOp>(loc, resultType, operands).asyncToken();
107 }
108
109 OpBuilder builder;
110 const Type tokenType = builder.getType<gpu::AsyncTokenType>();
111 // The token that represents the current asynchronous dependency. It's valid
112 // range starts with a `gpu.wait async` op, and ends with a `gpu.wait` op.
113 // In between, each gpu::AsyncOpInterface depends on the current token and
114 // produces the new one.
115 Value currentToken = {};
116 };
117
118 // Callback for `async.execute` ops which tries to push the contained
119 // synchronous `gpu.wait` op to the dependencies of the `async.execute`.
120 struct GpuAsyncRegionPass::DeferWaitCallback {
121 // If the `executeOp`s token is used only in `async.execute` or `async.await`
122 // ops, add the region's last `gpu.wait` op to the worklist if it is
123 // synchronous and is the last op with side effects.
operator ()GpuAsyncRegionPass::DeferWaitCallback124 void operator()(async::ExecuteOp executeOp) {
125 if (!areAllUsersExecuteOrAwait(executeOp.token()))
126 return;
127 // async.execute's region is currently restricted to one block.
128 for (auto &op : llvm::reverse(executeOp.getBody()->without_terminator())) {
129 if (auto waitOp = dyn_cast<gpu::WaitOp>(op)) {
130 if (!waitOp.asyncToken())
131 worklist.push_back(waitOp);
132 return;
133 }
134 if (hasSideEffects(&op))
135 return;
136 }
137 }
138
139 // The destructor performs the actual rewrite work.
~DeferWaitCallbackGpuAsyncRegionPass::DeferWaitCallback140 ~DeferWaitCallback() {
141 for (size_t i = 0; i < worklist.size(); ++i) {
142 auto waitOp = worklist[i];
143 auto executeOp = waitOp->getParentOfType<async::ExecuteOp>();
144 auto numDependencies = waitOp.asyncDependencies().size();
145
146 // Erase `gpu.wait` and return async dependencies from region instead.
147 auto &yieldOp = executeOp.getBody()->getOperations().back();
148 yieldOp.insertOperands(yieldOp.getNumOperands(),
149 waitOp.asyncDependencies());
150 waitOp.erase();
151 auto asyncTokens = addAsyncTokenResults(executeOp, numDependencies);
152
153 // Add the async dependency to each user of the `async.execute` token.
154 for (Operation *user : executeOp.token().getUsers())
155 addAsyncDependencyAfter(asyncTokens, user);
156 }
157 }
158
159 private:
160 // Append `count` `!async.value<!gpu.async.token>` results to `executeOp`.
addAsyncTokenResultsGpuAsyncRegionPass::DeferWaitCallback161 static ValueRange addAsyncTokenResults(async::ExecuteOp &executeOp,
162 unsigned count) {
163 auto numResults = executeOp.getNumResults() + count;
164
165 // Construct new result type list with `count` additional types.
166 SmallVector<Type, 2> resultTypes;
167 resultTypes.reserve(numResults);
168 copy(executeOp.getResultTypes(), std::back_inserter(resultTypes));
169 OpBuilder builder(executeOp);
170 auto tokenType = builder.getType<gpu::AsyncTokenType>();
171 resultTypes.resize(numResults, tokenType);
172
173 // Clone executeOp with the extra `!gpu.async.token` results.
174 auto newOp = builder.create<async::ExecuteOp>(
175 executeOp.getLoc(), TypeRange{resultTypes}.drop_front() /*drop token*/,
176 executeOp.dependencies(), executeOp.operands());
177 BlockAndValueMapping mapper;
178 newOp.getRegion().getBlocks().clear();
179 executeOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
180
181 // Replace executeOp with cloned one.
182 executeOp.getOperation()->replaceAllUsesWith(
183 newOp.getResults().drop_back(count));
184 executeOp.erase();
185 executeOp = newOp;
186
187 // Return the new result values.
188 return executeOp.getResults().take_back(count);
189 }
190
191 // Returns whether all token users are either 'async.execute' or 'async.await'
192 // ops. This is used as a requirement for pushing 'gpu.wait' ops from a
193 // 'async.execute' body to it's users. Specifically, we do not allow
194 // terminator users, because it could mean that the `async.execute` is inside
195 // control flow code.
areAllUsersExecuteOrAwaitGpuAsyncRegionPass::DeferWaitCallback196 static bool areAllUsersExecuteOrAwait(Value token) {
197 return llvm::all_of(token.getUsers(), [](Operation *user) {
198 return isa<async::ExecuteOp, async::AwaitOp>(user);
199 });
200 }
201
202 // Add the `asyncToken` as dependency as needed after `op`.
addAsyncDependencyAfterGpuAsyncRegionPass::DeferWaitCallback203 void addAsyncDependencyAfter(ValueRange asyncTokens, Operation *op) {
204 OpBuilder builder(op->getContext());
205 auto loc = op->getLoc();
206
207 Block::iterator it;
208 SmallVector<Value, 1> tokens;
209 tokens.reserve(asyncTokens.size());
210 TypeSwitch<Operation *>(op)
211 .Case<async::AwaitOp>([&](auto awaitOp) {
212 // Add async.await ops to wait for the !gpu.async.tokens.
213 builder.setInsertionPointAfter(op);
214 for (auto asyncToken : asyncTokens)
215 tokens.push_back(
216 builder.create<async::AwaitOp>(loc, asyncToken).result());
217 // Set `it` after the inserted async.await ops.
218 it = builder.getInsertionPoint();
219 })
220 .Case<async::ExecuteOp>([&](auto executeOp) {
221 // Set `it` to the beginning of the region and add asyncTokens to the
222 // async.execute operands.
223 it = executeOp.getBody()->begin();
224 executeOp.operandsMutable().append(asyncTokens);
225 SmallVector<Type, 1> tokenTypes(
226 asyncTokens.size(), builder.getType<gpu::AsyncTokenType>());
227 copy(executeOp.getBody()->addArguments(tokenTypes),
228 std::back_inserter(tokens));
229 });
230
231 // Advance `it` to terminator or op with side-effects.
232 it = std::find_if(it, Block::iterator(), [](Operation &op) {
233 return isTerminator(&op) || hasSideEffects(&op);
234 });
235
236 // If `op` implements the AsyncOpInterface, add `token` to the list of async
237 // dependencies.
238 if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(*it)) {
239 for (auto token : tokens)
240 asyncOp.addAsyncDependency(token);
241 return;
242 }
243
244 // Otherwise, insert a gpu.wait before 'it'.
245 builder.setInsertionPoint(it->getBlock(), it);
246 auto waitOp = builder.create<gpu::WaitOp>(loc, Type{}, tokens);
247
248 // If the new waitOp is at the end of an async.execute region, add it to the
249 // worklist. 'operator()(executeOp)' would do the same, but this is faster.
250 auto executeOp = dyn_cast<async::ExecuteOp>(it->getParentOp());
251 if (executeOp && areAllUsersExecuteOrAwait(executeOp.token()) &&
252 !it->getNextNode())
253 worklist.push_back(waitOp);
254 }
255
256 SmallVector<gpu::WaitOp, 8> worklist;
257 };
258
259 // Replaces synchronous GPU ops in the op's region with asynchronous ones and
260 // inserts the necessary synchronization (as gpu.wait ops). Assumes sequential
261 // execution semantics and that no GPU ops are asynchronous yet.
runOnFunction()262 void GpuAsyncRegionPass::runOnFunction() {
263 if (getFunction()
264 .getRegion()
265 .walk(ThreadTokenCallback(getContext()))
266 .wasInterrupted())
267 return signalPassFailure();
268
269 // Collect gpu.wait ops that we can move out of gpu.execute regions.
270 getFunction().getRegion().walk(DeferWaitCallback());
271 }
272
createGpuAsyncRegionPass()273 std::unique_ptr<OperationPass<FuncOp>> mlir::createGpuAsyncRegionPass() {
274 return std::make_unique<GpuAsyncRegionPass>();
275 }
276