• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- AsyncRefCounting.cpp - Implementation of Async Ref Counting --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements automatic reference counting for Async dialect data
10 // types.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Analysis/Liveness.h"
16 #include "mlir/Dialect/Async/IR/Async.h"
17 #include "mlir/Dialect/Async/Passes.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 #include "llvm/ADT/SmallSet.h"
22 
23 using namespace mlir;
24 using namespace mlir::async;
25 
26 #define DEBUG_TYPE "async-ref-counting"
27 
28 namespace {
29 
30 class AsyncRefCountingPass : public AsyncRefCountingBase<AsyncRefCountingPass> {
31 public:
32   AsyncRefCountingPass() = default;
33   void runOnFunction() override;
34 
35 private:
36   /// Adds an automatic reference counting to the `value`.
37   ///
38   /// All values are semantically created with a reference count of +1 and it is
39   /// the responsibility of the last async value user to drop reference count.
40   ///
41   /// Async values created when:
42   ///   1. Operation returns async result (e.g. the result of an
43   ///      `async.execute`).
44   ///   2. Async value passed in as a block argument.
45   ///
46   /// To implement automatic reference counting, we must insert a +1 reference
47   /// before each `async.execute` operation using the value, and drop it after
48   /// the last use inside the async body region (we currently drop the reference
49   /// before the `async.yield` terminator).
50   ///
51   /// Automatic reference counting algorithm outline:
52   ///
53   /// 1. `ReturnLike` operations forward the reference counted values without
54   ///     modifying the reference count.
55   ///
56   /// 2. Use liveness analysis to find blocks in the CFG where the lifetime of
57   ///    reference counted values ends, and insert `drop_ref` operations after
58   ///    the last use of the value.
59   ///
60   /// 3. Insert `add_ref` before the `async.execute` operation capturing the
61   ///    value, and pairing `drop_ref` before the async body region terminator,
62   ///    to release the captured reference counted value when execution
63   ///    completes.
64   ///
65   /// 4. If the reference counted value is passed only to some of the block
66   ///    successors, insert `drop_ref` operations in the beginning of the blocks
67   ///    that do not have reference counted value uses.
68   ///
69   ///
70   /// Example:
71   ///
72   ///   %token = ...
73   ///   async.execute {
74   ///     async.await %token : !async.token   // await #1
75   ///     async.yield
76   ///   }
77   ///   async.await %token : !async.token     // await #2
78   ///
79   /// Based on the liveness analysis await #2 is the last use of the %token,
80   /// however the execution of the async region can be delayed, and to guarantee
81   /// that the %token is still alive when await #1 executes we need to
82   /// explicitly extend its lifetime using `add_ref` operation.
83   ///
84   /// After automatic reference counting:
85   ///
86   ///   %token = ...
87   ///
88   ///   // Make sure that %token is alive inside async.execute.
89   ///   async.add_ref %token {count = 1 : i32} : !async.token
90   ///
91   ///   async.execute {
92   ///     async.await %token : !async.token   // await #1
93   ///
94   ///     // Drop the extra reference added to keep %token alive.
95   ///     async.drop_ref %token {count = 1 : i32} : !async.token
96   ///
97   ///     async.yied
98   ///   }
99   ///   async.await %token : !async.token     // await #2
100   ///
101   ///   // Drop the reference after the last use of %token.
102   ///   async.drop_ref %token {count = 1 : i32} : !async.token
103   ///
104   LogicalResult addAutomaticRefCounting(Value value);
105 };
106 
107 } // namespace
108 
addAutomaticRefCounting(Value value)109 LogicalResult AsyncRefCountingPass::addAutomaticRefCounting(Value value) {
110   MLIRContext *ctx = value.getContext();
111   OpBuilder builder(ctx);
112 
113   // Set inserton point after the operation producing a value, or at the
114   // beginning of the block if the value defined by the block argument.
115   if (Operation *op = value.getDefiningOp())
116     builder.setInsertionPointAfter(op);
117   else
118     builder.setInsertionPointToStart(value.getParentBlock());
119 
120   Location loc = value.getLoc();
121   auto i32 = IntegerType::get(32, ctx);
122 
123   // Drop the reference count immediately if the value has no uses.
124   if (value.getUses().empty()) {
125     builder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1));
126     return success();
127   }
128 
129   // Use liveness analysis to find the placement of `drop_ref`operation.
130   auto liveness = getAnalysis<Liveness>();
131 
132   // We analyse only the blocks of the region that defines the `value`, and do
133   // not check nested blocks attached to operations.
134   //
135   // By analyzing only the `definingRegion` CFG we potentially loose an
136   // opportunity to drop the reference count earlier and can extend the lifetime
137   // of reference counted value longer then it is really required.
138   //
139   // We also assume that all nested regions finish their execution before the
140   // completion of the owner operation. The only exception to this rule is
141   // `async.execute` operation, which is handled explicitly below.
142   Region *definingRegion = value.getParentRegion();
143 
144   // ------------------------------------------------------------------------ //
145   // Find blocks where the `value` dies: the value is in `liveIn` set and not
146   // in the `liveOut` set. We place `drop_ref` immediately after the last use
147   // of the `value` in such regions.
148   // ------------------------------------------------------------------------ //
149 
150   // Last users of the `value` inside all blocks where the value dies.
151   llvm::SmallSet<Operation *, 4> lastUsers;
152 
153   for (Block &block : definingRegion->getBlocks()) {
154     const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
155 
156     // Value in live input set or was defined in the block.
157     bool liveIn = blockLiveness->isLiveIn(value) ||
158                   blockLiveness->getBlock() == value.getParentBlock();
159     if (!liveIn)
160       continue;
161 
162     // Value is in the live out set.
163     bool liveOut = blockLiveness->isLiveOut(value);
164     if (liveOut)
165       continue;
166 
167     // We proved that `value` dies in the `block`. Now find the last use of the
168     // `value` inside the `block`.
169 
170     // Find any user of the `value` inside the block (including uses in nested
171     // regions attached to the operations in the block).
172     Operation *userInTheBlock = nullptr;
173     for (Operation *user : value.getUsers()) {
174       userInTheBlock = block.findAncestorOpInBlock(*user);
175       if (userInTheBlock)
176         break;
177     }
178 
179     // Values with zero users handled explicitly in the beginning, if the value
180     // is in live out set it must have at least one use in the block.
181     assert(userInTheBlock && "value must have a user in the block");
182 
183     // Find the last user of the `value` in the block;
184     Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock);
185     assert(lastUsers.count(lastUser) == 0 && "last users must be unique");
186     lastUsers.insert(lastUser);
187   }
188 
189   // Process all the last users of the `value` inside each block where the value
190   // dies.
191   for (Operation *lastUser : lastUsers) {
192     // Return like operations forward reference count.
193     if (lastUser->hasTrait<OpTrait::ReturnLike>())
194       continue;
195 
196     // We can't currently handle other types of terminators.
197     if (lastUser->hasTrait<OpTrait::IsTerminator>())
198       return lastUser->emitError() << "async reference counting can't handle "
199                                       "terminators that are not ReturnLike";
200 
201     // Add a drop_ref immediately after the last user.
202     builder.setInsertionPointAfter(lastUser);
203     builder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1));
204   }
205 
206   // ------------------------------------------------------------------------ //
207   // Find blocks where the `value` is in `liveOut` set, however it is not in
208   // the `liveIn` set of all successors. If the `value` is not in the successor
209   // `liveIn` set, we add a `drop_ref` to the beginning of it.
210   // ------------------------------------------------------------------------ //
211 
212   // Successors that we'll need a `drop_ref` for the `value`.
213   llvm::SmallSet<Block *, 4> dropRefSuccessors;
214 
215   for (Block &block : definingRegion->getBlocks()) {
216     const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
217 
218     // Skip the block if value is not in the `liveOut` set.
219     if (!blockLiveness->isLiveOut(value))
220       continue;
221 
222     // Find successors that do not have `value` in the `liveIn` set.
223     for (Block *successor : block.getSuccessors()) {
224       const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor);
225 
226       if (!succLiveness->isLiveIn(value))
227         dropRefSuccessors.insert(successor);
228     }
229   }
230 
231   // Drop reference in all successor blocks that do not have the `value` in
232   // their `liveIn` set.
233   for (Block *dropRefSuccessor : dropRefSuccessors) {
234     builder.setInsertionPointToStart(dropRefSuccessor);
235     builder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1));
236   }
237 
238   // ------------------------------------------------------------------------ //
239   // Find all `async.execute` operation that take `value` as an operand
240   // (dependency token or async value), or capture implicitly by the nested
241   // region. Each `async.execute` operation will require `add_ref` operation
242   // to keep all captured values alive until it will finish its execution.
243   // ------------------------------------------------------------------------ //
244 
245   llvm::SmallSet<ExecuteOp, 4> executeOperations;
246 
247   auto trackAsyncExecute = [&](Operation *op) {
248     if (auto execute = dyn_cast<ExecuteOp>(op))
249       executeOperations.insert(execute);
250   };
251 
252   for (Operation *user : value.getUsers()) {
253     // Follow parent operations up until the operation in the `definingRegion`.
254     while (user->getParentRegion() != definingRegion) {
255       trackAsyncExecute(user);
256       user = user->getParentOp();
257       assert(user != nullptr && "value user lies outside of the value region");
258     }
259 
260     // Don't forget to process the parent in the `definingRegion` (can be the
261     // original user operation itself).
262     trackAsyncExecute(user);
263   }
264 
265   // Process all `async.execute` operations capturing `value`.
266   for (ExecuteOp execute : executeOperations) {
267     // Add a reference before the execute operation to keep the reference
268     // counted alive before the async region completes execution.
269     builder.setInsertionPoint(execute.getOperation());
270     builder.create<AddRefOp>(loc, value, IntegerAttr::get(i32, 1));
271 
272     // Drop the reference inside the async region before completion.
273     OpBuilder executeBuilder = OpBuilder::atBlockTerminator(execute.getBody());
274     executeBuilder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1));
275   }
276 
277   return success();
278 }
279 
runOnFunction()280 void AsyncRefCountingPass::runOnFunction() {
281   FuncOp func = getFunction();
282 
283   // Check that we do not have explicit `add_ref` or `drop_ref` in the IR
284   // because otherwise automatic reference counting will produce incorrect
285   // results.
286   WalkResult refCountingWalk = func.walk([&](Operation *op) -> WalkResult {
287     if (isa<AddRefOp, DropRefOp>(op))
288       return op->emitError() << "explicit reference counting is not supported";
289     return WalkResult::advance();
290   });
291 
292   if (refCountingWalk.wasInterrupted())
293     signalPassFailure();
294 
295   // Add reference counting to block arguments.
296   WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
297     for (BlockArgument arg : block->getArguments())
298       if (isRefCounted(arg.getType()))
299         if (failed(addAutomaticRefCounting(arg)))
300           return WalkResult::interrupt();
301 
302     return WalkResult::advance();
303   });
304 
305   if (blockWalk.wasInterrupted())
306     signalPassFailure();
307 
308   // Add reference counting to operation results.
309   WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
310     for (unsigned i = 0; i < op->getNumResults(); ++i)
311       if (isRefCounted(op->getResultTypes()[i]))
312         if (failed(addAutomaticRefCounting(op->getResult(i))))
313           return WalkResult::interrupt();
314 
315     return WalkResult::advance();
316   });
317 
318   if (opWalk.wasInterrupted())
319     signalPassFailure();
320 }
321 
createAsyncRefCountingPass()322 std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncRefCountingPass() {
323   return std::make_unique<AsyncRefCountingPass>();
324 }
325