1 //===- AsyncRefCounting.cpp - Implementation of Async Ref Counting --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements automatic reference counting for Async dialect data
10 // types.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "PassDetail.h"
15 #include "mlir/Analysis/Liveness.h"
16 #include "mlir/Dialect/Async/IR/Async.h"
17 #include "mlir/Dialect/Async/Passes.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 #include "llvm/ADT/SmallSet.h"
22
23 using namespace mlir;
24 using namespace mlir::async;
25
26 #define DEBUG_TYPE "async-ref-counting"
27
28 namespace {
29
30 class AsyncRefCountingPass : public AsyncRefCountingBase<AsyncRefCountingPass> {
31 public:
32 AsyncRefCountingPass() = default;
33 void runOnFunction() override;
34
35 private:
36 /// Adds an automatic reference counting to the `value`.
37 ///
38 /// All values are semantically created with a reference count of +1 and it is
39 /// the responsibility of the last async value user to drop reference count.
40 ///
41 /// Async values created when:
42 /// 1. Operation returns async result (e.g. the result of an
43 /// `async.execute`).
44 /// 2. Async value passed in as a block argument.
45 ///
46 /// To implement automatic reference counting, we must insert a +1 reference
47 /// before each `async.execute` operation using the value, and drop it after
48 /// the last use inside the async body region (we currently drop the reference
49 /// before the `async.yield` terminator).
50 ///
51 /// Automatic reference counting algorithm outline:
52 ///
53 /// 1. `ReturnLike` operations forward the reference counted values without
54 /// modifying the reference count.
55 ///
56 /// 2. Use liveness analysis to find blocks in the CFG where the lifetime of
57 /// reference counted values ends, and insert `drop_ref` operations after
58 /// the last use of the value.
59 ///
60 /// 3. Insert `add_ref` before the `async.execute` operation capturing the
61 /// value, and pairing `drop_ref` before the async body region terminator,
62 /// to release the captured reference counted value when execution
63 /// completes.
64 ///
65 /// 4. If the reference counted value is passed only to some of the block
66 /// successors, insert `drop_ref` operations in the beginning of the blocks
67 /// that do not have reference counted value uses.
68 ///
69 ///
70 /// Example:
71 ///
72 /// %token = ...
73 /// async.execute {
74 /// async.await %token : !async.token // await #1
75 /// async.yield
76 /// }
77 /// async.await %token : !async.token // await #2
78 ///
79 /// Based on the liveness analysis await #2 is the last use of the %token,
80 /// however the execution of the async region can be delayed, and to guarantee
81 /// that the %token is still alive when await #1 executes we need to
82 /// explicitly extend its lifetime using `add_ref` operation.
83 ///
84 /// After automatic reference counting:
85 ///
86 /// %token = ...
87 ///
88 /// // Make sure that %token is alive inside async.execute.
89 /// async.add_ref %token {count = 1 : i32} : !async.token
90 ///
91 /// async.execute {
92 /// async.await %token : !async.token // await #1
93 ///
94 /// // Drop the extra reference added to keep %token alive.
95 /// async.drop_ref %token {count = 1 : i32} : !async.token
96 ///
97 /// async.yied
98 /// }
99 /// async.await %token : !async.token // await #2
100 ///
101 /// // Drop the reference after the last use of %token.
102 /// async.drop_ref %token {count = 1 : i32} : !async.token
103 ///
104 LogicalResult addAutomaticRefCounting(Value value);
105 };
106
107 } // namespace
108
addAutomaticRefCounting(Value value)109 LogicalResult AsyncRefCountingPass::addAutomaticRefCounting(Value value) {
110 MLIRContext *ctx = value.getContext();
111 OpBuilder builder(ctx);
112
113 // Set inserton point after the operation producing a value, or at the
114 // beginning of the block if the value defined by the block argument.
115 if (Operation *op = value.getDefiningOp())
116 builder.setInsertionPointAfter(op);
117 else
118 builder.setInsertionPointToStart(value.getParentBlock());
119
120 Location loc = value.getLoc();
121 auto i32 = IntegerType::get(32, ctx);
122
123 // Drop the reference count immediately if the value has no uses.
124 if (value.getUses().empty()) {
125 builder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1));
126 return success();
127 }
128
129 // Use liveness analysis to find the placement of `drop_ref`operation.
130 auto liveness = getAnalysis<Liveness>();
131
132 // We analyse only the blocks of the region that defines the `value`, and do
133 // not check nested blocks attached to operations.
134 //
135 // By analyzing only the `definingRegion` CFG we potentially loose an
136 // opportunity to drop the reference count earlier and can extend the lifetime
137 // of reference counted value longer then it is really required.
138 //
139 // We also assume that all nested regions finish their execution before the
140 // completion of the owner operation. The only exception to this rule is
141 // `async.execute` operation, which is handled explicitly below.
142 Region *definingRegion = value.getParentRegion();
143
144 // ------------------------------------------------------------------------ //
145 // Find blocks where the `value` dies: the value is in `liveIn` set and not
146 // in the `liveOut` set. We place `drop_ref` immediately after the last use
147 // of the `value` in such regions.
148 // ------------------------------------------------------------------------ //
149
150 // Last users of the `value` inside all blocks where the value dies.
151 llvm::SmallSet<Operation *, 4> lastUsers;
152
153 for (Block &block : definingRegion->getBlocks()) {
154 const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
155
156 // Value in live input set or was defined in the block.
157 bool liveIn = blockLiveness->isLiveIn(value) ||
158 blockLiveness->getBlock() == value.getParentBlock();
159 if (!liveIn)
160 continue;
161
162 // Value is in the live out set.
163 bool liveOut = blockLiveness->isLiveOut(value);
164 if (liveOut)
165 continue;
166
167 // We proved that `value` dies in the `block`. Now find the last use of the
168 // `value` inside the `block`.
169
170 // Find any user of the `value` inside the block (including uses in nested
171 // regions attached to the operations in the block).
172 Operation *userInTheBlock = nullptr;
173 for (Operation *user : value.getUsers()) {
174 userInTheBlock = block.findAncestorOpInBlock(*user);
175 if (userInTheBlock)
176 break;
177 }
178
179 // Values with zero users handled explicitly in the beginning, if the value
180 // is in live out set it must have at least one use in the block.
181 assert(userInTheBlock && "value must have a user in the block");
182
183 // Find the last user of the `value` in the block;
184 Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock);
185 assert(lastUsers.count(lastUser) == 0 && "last users must be unique");
186 lastUsers.insert(lastUser);
187 }
188
189 // Process all the last users of the `value` inside each block where the value
190 // dies.
191 for (Operation *lastUser : lastUsers) {
192 // Return like operations forward reference count.
193 if (lastUser->hasTrait<OpTrait::ReturnLike>())
194 continue;
195
196 // We can't currently handle other types of terminators.
197 if (lastUser->hasTrait<OpTrait::IsTerminator>())
198 return lastUser->emitError() << "async reference counting can't handle "
199 "terminators that are not ReturnLike";
200
201 // Add a drop_ref immediately after the last user.
202 builder.setInsertionPointAfter(lastUser);
203 builder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1));
204 }
205
206 // ------------------------------------------------------------------------ //
207 // Find blocks where the `value` is in `liveOut` set, however it is not in
208 // the `liveIn` set of all successors. If the `value` is not in the successor
209 // `liveIn` set, we add a `drop_ref` to the beginning of it.
210 // ------------------------------------------------------------------------ //
211
212 // Successors that we'll need a `drop_ref` for the `value`.
213 llvm::SmallSet<Block *, 4> dropRefSuccessors;
214
215 for (Block &block : definingRegion->getBlocks()) {
216 const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
217
218 // Skip the block if value is not in the `liveOut` set.
219 if (!blockLiveness->isLiveOut(value))
220 continue;
221
222 // Find successors that do not have `value` in the `liveIn` set.
223 for (Block *successor : block.getSuccessors()) {
224 const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor);
225
226 if (!succLiveness->isLiveIn(value))
227 dropRefSuccessors.insert(successor);
228 }
229 }
230
231 // Drop reference in all successor blocks that do not have the `value` in
232 // their `liveIn` set.
233 for (Block *dropRefSuccessor : dropRefSuccessors) {
234 builder.setInsertionPointToStart(dropRefSuccessor);
235 builder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1));
236 }
237
238 // ------------------------------------------------------------------------ //
239 // Find all `async.execute` operation that take `value` as an operand
240 // (dependency token or async value), or capture implicitly by the nested
241 // region. Each `async.execute` operation will require `add_ref` operation
242 // to keep all captured values alive until it will finish its execution.
243 // ------------------------------------------------------------------------ //
244
245 llvm::SmallSet<ExecuteOp, 4> executeOperations;
246
247 auto trackAsyncExecute = [&](Operation *op) {
248 if (auto execute = dyn_cast<ExecuteOp>(op))
249 executeOperations.insert(execute);
250 };
251
252 for (Operation *user : value.getUsers()) {
253 // Follow parent operations up until the operation in the `definingRegion`.
254 while (user->getParentRegion() != definingRegion) {
255 trackAsyncExecute(user);
256 user = user->getParentOp();
257 assert(user != nullptr && "value user lies outside of the value region");
258 }
259
260 // Don't forget to process the parent in the `definingRegion` (can be the
261 // original user operation itself).
262 trackAsyncExecute(user);
263 }
264
265 // Process all `async.execute` operations capturing `value`.
266 for (ExecuteOp execute : executeOperations) {
267 // Add a reference before the execute operation to keep the reference
268 // counted alive before the async region completes execution.
269 builder.setInsertionPoint(execute.getOperation());
270 builder.create<AddRefOp>(loc, value, IntegerAttr::get(i32, 1));
271
272 // Drop the reference inside the async region before completion.
273 OpBuilder executeBuilder = OpBuilder::atBlockTerminator(execute.getBody());
274 executeBuilder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1));
275 }
276
277 return success();
278 }
279
runOnFunction()280 void AsyncRefCountingPass::runOnFunction() {
281 FuncOp func = getFunction();
282
283 // Check that we do not have explicit `add_ref` or `drop_ref` in the IR
284 // because otherwise automatic reference counting will produce incorrect
285 // results.
286 WalkResult refCountingWalk = func.walk([&](Operation *op) -> WalkResult {
287 if (isa<AddRefOp, DropRefOp>(op))
288 return op->emitError() << "explicit reference counting is not supported";
289 return WalkResult::advance();
290 });
291
292 if (refCountingWalk.wasInterrupted())
293 signalPassFailure();
294
295 // Add reference counting to block arguments.
296 WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
297 for (BlockArgument arg : block->getArguments())
298 if (isRefCounted(arg.getType()))
299 if (failed(addAutomaticRefCounting(arg)))
300 return WalkResult::interrupt();
301
302 return WalkResult::advance();
303 });
304
305 if (blockWalk.wasInterrupted())
306 signalPassFailure();
307
308 // Add reference counting to operation results.
309 WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
310 for (unsigned i = 0; i < op->getNumResults(); ++i)
311 if (isRefCounted(op->getResultTypes()[i]))
312 if (failed(addAutomaticRefCounting(op->getResult(i))))
313 return WalkResult::interrupt();
314
315 return WalkResult::advance();
316 });
317
318 if (opWalk.wasInterrupted())
319 signalPassFailure();
320 }
321
createAsyncRefCountingPass()322 std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncRefCountingPass() {
323 return std::make_unique<AsyncRefCountingPass>();
324 }
325