//===- AsyncRefCountingOptimization.cpp - Async Ref Counting --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Optimize Async dialect reference counting operations.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Async/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace mlir::async;

#define DEBUG_TYPE "async-ref-counting"

namespace {

class AsyncRefCountingOptimizationPass
    : public AsyncRefCountingOptimizationBase<
          AsyncRefCountingOptimizationPass> {
public:
  AsyncRefCountingOptimizationPass() = default;
  void runOnFunction() override;

private:
  LogicalResult optimizeReferenceCounting(Value value);
};

} // namespace

LogicalResult
AsyncRefCountingOptimizationPass::optimizeReferenceCounting(Value value) {
  Region *definingRegion = value.getParentRegion();

  // Find all users of the `value` inside each block, including operations
  // that do not use `value` directly but have a direct use inside nested
  // region(s).
  //
  // Example:
  //
  //  ^bb1:
  //    %token = ...
  //    scf.if %cond {
  //      ^bb2:
  //      async.await %token : !async.token
  //    }
  //
  // %token has a use inside ^bb2 (`async.await`) and inside ^bb1 (`scf.if`).
  //
  // In addition to the operation that uses the `value`, we also keep track of
  // whether this user is an `async.execute` operation itself, or has
  // `async.execute` operations in its nested regions that do use the `value`.
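  //
  // For the `async.execute` case, a hedged IR sketch (op syntax abbreviated,
  // not taken verbatim from a test case):
  //
  //  ^bb1:
  //    %token = ...
  //    %done = async.execute {
  //      async.await %token : !async.token
  //      async.yield
  //    }
  //
  // Here the `async.execute` user of `%token` is recorded with
  // `hasExecuteUser = true`, because the direct use lies in its body region.
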
59 
60   struct UserInfo {
61     Operation *operation;
62     bool hasExecuteUser;
63   };
64 
65   struct BlockUsersInfo {
66     llvm::SmallVector<AddRefOp, 4> addRefs;
67     llvm::SmallVector<DropRefOp, 4> dropRefs;
68     llvm::SmallVector<UserInfo, 4> users;
69   };
70 
71   llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;
72 
73   auto updateBlockUsersInfo = [&](UserInfo user) {
74     BlockUsersInfo &info = blockUsers[user.operation->getBlock()];
75     info.users.push_back(user);
76 
77     if (auto addRef = dyn_cast<AddRefOp>(user.operation))
78       info.addRefs.push_back(addRef);
79     if (auto dropRef = dyn_cast<DropRefOp>(user.operation))
80       info.dropRefs.push_back(dropRef);
81   };
82 
83   for (Operation *user : value.getUsers()) {
84     bool isAsyncUser = isa<ExecuteOp>(user);
85 
    while (user->getParentRegion() != definingRegion) {
      updateBlockUsersInfo({user, isAsyncUser});
      user = user->getParentOp();
      assert(user != nullptr && "value user lies outside of the value region");
      isAsyncUser |= isa<ExecuteOp>(user);
    }

    updateBlockUsersInfo({user, isAsyncUser});
  }

  // Sort all operations found in the block.
  auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
    auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
      return a->isBeforeInBlock(b);
    };
    llvm::sort(info.addRefs, isBeforeInBlock);
    llvm::sort(info.dropRefs, isBeforeInBlock);
    llvm::sort(info.users, [&](UserInfo a, UserInfo b) -> bool {
      return isBeforeInBlock(a.operation, b.operation);
    });

    return info;
  };

  // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the
  // blocks that modify the reference count of the `value`.
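  //
  // As an illustration (a hedged sketch, not output captured from this pass),
  // a trivially cancellable pair looks like:
  //
  //   async.add_ref %token {count = 1 : i32} : !async.token
  //   async.await %token : !async.token
  //   async.drop_ref %token {count = 1 : i32} : !async.token
  //
  // Erasing the pair leaves the `async.await` in place; it is safe because no
  // `async.execute` user sits between the pair, so nothing can asynchronously
  // release `%token` before the synchronous user runs.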
  for (auto &kv : blockUsers) {
    BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);

    // Find all cancellable pairs first and erase them later to keep all
    // pointers in the `info` valid until the end.
    //
    // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
    llvm::SmallDenseMap<Operation *, Operation *> cancellable;

    for (AddRefOp addRef : info.addRefs) {
      for (DropRefOp dropRef : info.dropRefs) {
        // `drop_ref` operation after the `add_ref` with matching count.
        if (dropRef.count() != addRef.count() ||
            dropRef->isBeforeInBlock(addRef.getOperation()))
          continue;

        // `drop_ref` was already marked for removal.
        if (cancellable.find(dropRef.getOperation()) != cancellable.end())
          continue;

        // Check `value` users between `addRef` and `dropRef` in the `block`.
        Operation *addRefOp = addRef.getOperation();
        Operation *dropRefOp = dropRef.getOperation();

        // If there is a "regular" user after an `async.execute` user, it is
        // unsafe to erase the cancellable pair of reference counting
        // operations, because the async region can complete before the
        // "regular" user and destroy the reference-counted value.
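        //
        // A hedged sketch of the unsafe pattern (op syntax abbreviated):
        //
        //   async.add_ref %token {count = 1 : i32} : !async.token
        //   %done = async.execute {
        //     async.await %token : !async.token  // use inside async region
        //     async.yield
        //   }
        //   async.await %token : !async.token    // "regular" user after it
        //   async.drop_ref %token {count = 1 : i32} : !async.token
        //
        // Erasing this pair would let the async region drop the count to
        // zero before the second `async.await` executes.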
        bool hasExecuteUser = false;
        bool unsafeToCancel = false;

        for (UserInfo &user : info.users) {
          Operation *op = user.operation;

          // `user` operation lies after `addRef` ...
          if (op == addRefOp || op->isBeforeInBlock(addRefOp))
            continue;
          // ... and before `dropRef`.
          if (op == dropRefOp || dropRefOp->isBeforeInBlock(op))
            break;

          bool isRegularUser = !user.hasExecuteUser;
          bool isExecuteUser = user.hasExecuteUser;

          // It is unsafe to cancel the `addRef` / `dropRef` pair.
          if (isRegularUser && hasExecuteUser) {
            unsafeToCancel = true;
            break;
          }

          hasExecuteUser |= isExecuteUser;
        }

        // Mark the pair of reference counting operations for removal.
        if (!unsafeToCancel)
          cancellable[dropRef.getOperation()] = addRef.getOperation();

        // If it is unsafe to cancel the `addRef <-> dropRef` pair at this
        // point, all the following pairs will also be unsafe.
        break;
      }
    }

    // Erase all cancellable `addRef <-> dropRef` operation pairs.
    for (auto &kv : cancellable) {
      kv.first->erase();
      kv.second->erase();
    }
  }

  return success();
}

void AsyncRefCountingOptimizationPass::runOnFunction() {
  FuncOp func = getFunction();

  // Optimize reference counting for values defined by block arguments.
  WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult {
    for (BlockArgument arg : block->getArguments())
      if (isRefCounted(arg.getType()))
        if (failed(optimizeReferenceCounting(arg)))
          return WalkResult::interrupt();

    return WalkResult::advance();
  });

  if (blockWalk.wasInterrupted())
    signalPassFailure();

  // Optimize reference counting for values defined by operation results.
  WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult {
    for (Value result : op->getResults())
      if (isRefCounted(result.getType()))
        if (failed(optimizeReferenceCounting(result)))
          return WalkResult::interrupt();

    return WalkResult::advance();
  });

  if (opWalk.wasInterrupted())
    signalPassFailure();
}

std::unique_ptr<OperationPass<FuncOp>>
mlir::createAsyncRefCountingOptimizationPass() {
  return std::make_unique<AsyncRefCountingOptimizationPass>();
}
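
// Usage sketch (assuming this pass is registered in Passes.td under the
// "async-ref-counting-optimization" command-line flag; the flag name is an
// assumption, not verified from this file):
//
//   mlir-opt input.mlir -async-ref-counting-optimization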