//===- AsyncRefCountingOptimization.cpp - Async Ref Counting --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Optimize Async dialect reference counting operations.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "llvm/ADT/SmallSet.h"

using namespace mlir;
using namespace mlir::async;

#define DEBUG_TYPE "async-ref-counting"

namespace {

class AsyncRefCountingOptimizationPass
    : public AsyncRefCountingOptimizationBase<
          AsyncRefCountingOptimizationPass> {
public:
  AsyncRefCountingOptimizationPass() = default;
  void runOnFunction() override;

private:
  LogicalResult optimizeReferenceCounting(Value value);
};

} // namespace

LogicalResult
AsyncRefCountingOptimizationPass::optimizeReferenceCounting(Value value) {
  Region *definingRegion = value.getParentRegion();

  // Find all users of the `value` inside each block, including operations that
  // do not use `value` directly, but have a direct use inside nested
  // region(s).
  //
  // Example:
  //
  //   ^bb1:
  //     %token = ...
  //     scf.if %cond {
  //       ^bb2:
  //         async.await %token : !async.token
  //     }
  //
  // %token has a use inside ^bb2 (`async.await`) and inside ^bb1 (`scf.if`).
  //
  // In addition to the operation that uses the `value`, we also keep track of
  // whether this user is an `async.execute` operation itself, or has
  // `async.execute` operations in its nested regions that use the `value`.
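  //
  // For example (illustrative IR), in the snippet below `scf.if` is recorded
  // as a user with `hasExecuteUser = true`, because an `async.execute`
  // operation nested inside its region uses `%token`:
  //
  //   ^bb1:
  //     %token = ...
  //     scf.if %cond {
  //       async.execute {
  //         async.await %token : !async.token
  //       }
  //     }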

  struct UserInfo {
    Operation *operation;
    bool hasExecuteUser;
  };

  struct BlockUsersInfo {
    llvm::SmallVector<AddRefOp, 4> addRefs;
    llvm::SmallVector<DropRefOp, 4> dropRefs;
    llvm::SmallVector<UserInfo, 4> users;
  };

  llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;

  auto updateBlockUsersInfo = [&](UserInfo user) {
    BlockUsersInfo &info = blockUsers[user.operation->getBlock()];
    info.users.push_back(user);

    if (auto addRef = dyn_cast<AddRefOp>(user.operation))
      info.addRefs.push_back(addRef);
    if (auto dropRef = dyn_cast<DropRefOp>(user.operation))
      info.dropRefs.push_back(dropRef);
  };

  for (Operation *user : value.getUsers()) {
    bool isAsyncUser = isa<ExecuteOp>(user);

    while (user->getParentRegion() != definingRegion) {
      updateBlockUsersInfo({user, isAsyncUser});
      user = user->getParentOp();
      assert(user != nullptr && "value user lies outside of the value region");
      isAsyncUser |= isa<ExecuteOp>(user);
    }

    updateBlockUsersInfo({user, isAsyncUser});
  }

  // Sort all operations found in the block.
  auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
    auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
      return a->isBeforeInBlock(b);
    };
    llvm::sort(info.addRefs, isBeforeInBlock);
    llvm::sort(info.dropRefs, isBeforeInBlock);
    llvm::sort(info.users, [&](UserInfo a, UserInfo b) -> bool {
      return isBeforeInBlock(a.operation, b.operation);
    });

    return info;
  };

  // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the
  // blocks that modify the reference count of the `value`.
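  //
  // For example (illustrative IR; exact attribute syntax may differ), the pair
  // below is cancellable, because the only user between the two operations is
  // a regular (synchronous) user, so the value can not be destroyed before the
  // `drop_ref` executes:
  //
  //   async.add_ref %token {count = 1 : i32} : !async.token
  //   async.await %token : !async.token
  //   async.drop_ref %token {count = 1 : i32} : !async.token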
  for (auto &kv : blockUsers) {
    BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);

    // Find all cancellable pairs first and erase them later to keep all
    // pointers in the `info` valid until the end.
    //
    // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
    llvm::SmallDenseMap<Operation *, Operation *> cancellable;

    for (AddRefOp addRef : info.addRefs) {
      for (DropRefOp dropRef : info.dropRefs) {
        // `drop_ref` operation after the `add_ref` with matching count.
        if (dropRef.count() != addRef.count() ||
            dropRef->isBeforeInBlock(addRef.getOperation()))
          continue;

        // `drop_ref` was already marked for removal.
        if (cancellable.find(dropRef.getOperation()) != cancellable.end())
          continue;

        // Check `value` users between `addRef` and `dropRef` in the `block`.
        Operation *addRefOp = addRef.getOperation();
        Operation *dropRefOp = dropRef.getOperation();
        // If there is a "regular" user after an `async.execute` user, it is
        // unsafe to erase the cancellable pair of reference counting
        // operations, because the async region can complete before the
        // "regular" user and destroy the reference counted value.
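        //
        // For example (illustrative IR), the pair below must be kept: without
        // it, the `async.execute` region could drop the last reference and
        // destroy %token before the trailing `async.await` runs:
        //
        //   async.add_ref %token {count = 1 : i32} : !async.token
        //   async.execute {
        //     async.await %token : !async.token
        //     async.yield
        //   }
        //   async.await %token : !async.token   // "regular" user
        //   async.drop_ref %token {count = 1 : i32} : !async.token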
        bool hasExecuteUser = false;
        bool unsafeToCancel = false;

        for (UserInfo &user : info.users) {
          Operation *op = user.operation;

          // `user` operation lies after `addRef` ...
          if (op == addRefOp || op->isBeforeInBlock(addRefOp))
            continue;
          // ... and before `dropRef`.
          if (op == dropRefOp || dropRefOp->isBeforeInBlock(op))
            break;

          bool isRegularUser = !user.hasExecuteUser;
          bool isExecuteUser = user.hasExecuteUser;

          // It is unsafe to cancel the `addRef` / `dropRef` pair if a
          // "regular" user comes after an `async.execute` user.
          if (isRegularUser && hasExecuteUser) {
            unsafeToCancel = true;
            break;
          }

          hasExecuteUser |= isExecuteUser;
        }

        // Mark the pair of reference counting operations for removal.
        if (!unsafeToCancel)
          cancellable[dropRef.getOperation()] = addRef.getOperation();

        // If it is unsafe to cancel the `addRef <-> dropRef` pair at this
        // point, all the following pairs will also be unsafe.
        break;
      }
    }

    // Erase all cancellable `addRef <-> dropRef` operation pairs.
    for (auto &kv : cancellable) {
      kv.first->erase();
      kv.second->erase();
    }
  }

  return success();
}

void AsyncRefCountingOptimizationPass::runOnFunction() {
  FuncOp func = getFunction();

  // Optimize reference counting for values defined by block arguments.
  WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
    for (BlockArgument arg : block->getArguments())
      if (isRefCounted(arg.getType()))
        if (failed(optimizeReferenceCounting(arg)))
          return WalkResult::interrupt();

    return WalkResult::advance();
  });

  if (blockWalk.wasInterrupted())
    signalPassFailure();

  // Optimize reference counting for values defined by operation results.
  WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
    for (unsigned i = 0; i < op->getNumResults(); ++i)
      if (isRefCounted(op->getResultTypes()[i]))
        if (failed(optimizeReferenceCounting(op->getResult(i))))
          return WalkResult::interrupt();

    return WalkResult::advance();
  });

  if (opWalk.wasInterrupted())
    signalPassFailure();
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createAsyncRefCountingOptimizationPass() {
  return std::make_unique<AsyncRefCountingOptimizationPass>();
}