1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/transforms/eliminate_passthrough_iter_args/pass.h"
17
18 #include <utility>
19
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/EpochTracker.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/Support/Debug.h"
24 #include "mlir/IR/MLIRContext.h" // from @llvm-project
25 #include "mlir/IR/PatternMatch.h" // from @llvm-project
26 #include "mlir/IR/Value.h" // from @llvm-project
27 #include "mlir/Support/LogicalResult.h" // from @llvm-project
28 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
29 #include "tensorflow/core/ir/ops.h"
30 #include "tensorflow/core/ir/utility.h"
31 #include "tensorflow/core/transforms/pass_detail.h"
32 #include "tensorflow/core/transforms/utils/utils.h"
33
// Define the debug label used by `LLVM_DEBUG`. It names this pass so that its
// debug output can be selected independently of other passes.
#define DEBUG_TYPE "eliminate-passthrough-iter-args"
36
37 namespace mlir {
38 namespace tfg {
39
40 // Given a range of elements, this function returns a vector of elements
41 // excluding the ones whose index is contained in a bit vector.
42 template <typename RangeT>
FilterByIndex(RangeT && range,const llvm::BitVector & indices)43 static SmallVector<llvm::detail::ValueOfRange<RangeT>> FilterByIndex(
44 RangeT &&range, const llvm::BitVector &indices) {
45 SmallVector<llvm::detail::ValueOfRange<RangeT>> result;
46 for (auto &it : llvm::enumerate(range))
47 if (!indices.test(it.index())) result.push_back(it.value());
48 return result;
49 }
50
51 // Given a region, return the indices of arguments that are passthrough.
GetPassthroughIndices(Region & region)52 static SmallVector<unsigned> GetPassthroughIndices(Region ®ion) {
53 Block::BlockArgListType iter_args = GetLoopRegionDataArgs(region);
54
55 // Skip the index argument for `For` ops.
56 if (isa<ForRegionOp>(region.getParentOp()))
57 iter_args = iter_args.drop_front();
58
59 SmallVector<unsigned> indices;
60 auto yield = cast<YieldOp>(region.front().getTerminator());
61 for (auto it : llvm::zip(llvm::enumerate(yield.args()), iter_args)) {
62 if (std::get<0>(it).value() == std::get<1>(it))
63 indices.push_back(std::get<0>(it).index());
64 }
65 return indices;
66 }
67
68 // Given a terminator, erase the iteration arguments at the specified index.
EraseIterArgsFromTerminator(Operation * terminator,ArrayRef<unsigned> indices)69 static void EraseIterArgsFromTerminator(Operation *terminator,
70 ArrayRef<unsigned> indices) {
71 if (isa<YieldOp>(terminator)) {
72 util::SizedOperandSegmentsEraseOperands(terminator, indices);
73 return;
74 }
75
76 // Skip the condition operand.
77 assert(isa<ConditionOp>(terminator));
78 SmallVector<unsigned> cond_indices = llvm::to_vector(indices);
79 for (unsigned &index : cond_indices) ++index;
80 util::SizedOperandSegmentsEraseOperands(terminator, cond_indices);
81 }
82
83 namespace {
84 template <typename ConcreteT, typename OpT>
85 struct EliminatePassthroughIterArgs {
86 // Eliminate the passthrough iteration arguments for the given op. Returns the
87 // number of eliminated arguments.
Runmlir::tfg::__anond15d86580111::EliminatePassthroughIterArgs88 static size_t Run(OpT op, IRRewriter &rewriter) {
89 SmallVector<unsigned> indices = GetPassthroughIndices(op.body_region());
90 if (indices.empty()) return 0;
91
92 LLVM_DEBUG(llvm::dbgs()
93 << "Number of captures to erase: " << indices.size() << "\n");
94 // We need to:
95 // 1. remove the terminator operands
96 // 2. remove the operands and corresponding results (and replace them)
97 // 3. remove the block arguments (and update preserved attributes)
98 llvm::BitVector remove(op.init().size());
99 for (unsigned index : indices) remove.set(index);
100
101 for (Region ®ion : op->getRegions())
102 EraseIterArgsFromTerminator(region.front().getTerminator(), indices);
103
104 OpT new_op = ConcreteT::RebuildOp(remove, op, rewriter);
105 util::ForwardNonIntrinsicAttributes(op, new_op);
106
107 // Replace uses of each passthrough argument with the implicit capture
108 // value and remove the argument. Insert the implicitly captured value into
109 // the result list to replace the removed results from the original op.
110 SmallVector<Value> results = llvm::to_vector(ValueRange(new_op.outs()));
111 for (auto &it : llvm::enumerate(indices)) {
112 unsigned idx = it.value() - it.index();
113 Value data = op.init()[it.value()];
114 results.insert(results.begin() + it.value(), data);
115 ConcreteT::ReplaceArguments(idx, new_op, data,
116 LookupControlDependency(data));
117 }
118 results.push_back(new_op.ctl());
119 rewriter.replaceOp(op, results);
120 return indices.size();
121 }
122 };
123
124 struct EliminateForPassthroughIterArgs
125 : public EliminatePassthroughIterArgs<EliminateForPassthroughIterArgs,
126 ForRegionOp> {
RebuildOpmlir::tfg::__anond15d86580111::EliminateForPassthroughIterArgs127 static ForRegionOp RebuildOp(const llvm::BitVector &indices, ForRegionOp op,
128 IRRewriter &rewriter) {
129 rewriter.setInsertionPoint(op);
130 auto new_op = rewriter.create<ForRegionOp>(
131 op.getLoc(), FilterByIndex(op.outs().getTypes(), indices),
132 op.ctl().getType(), op.start(), op.limit(), op.delta(),
133 FilterByIndex(op.init(), indices), op.ctls(), op.body_attrsAttr(),
134 op.region_attrsAttr());
135 new_op.body_region().takeBody(op.body_region());
136 return new_op;
137 }
138
ReplaceArgumentsmlir::tfg::__anond15d86580111::EliminateForPassthroughIterArgs139 static void ReplaceArguments(unsigned index, ForRegionOp op, Value data,
140 Value ctl) {
141 // Argument indexing starts from 1 (skip the loop index argument).
142 GetLoopRegionDataArgs(op.body_region())[index + 1].replaceAllUsesWith(data);
143 GetLoopRegionControlTokens(op.body_region())[index + 1].replaceAllUsesWith(
144 ctl);
145 util::LoopRegionEraseArgument(op.body_region(), index + 1);
146 util::LoopRegionResultErased(op.body_region(), index);
147 }
148 };
149
150 template <typename WhileLikeRegionOp>
151 struct EliminateWhileLikePassthroughIterArgs
152 : public EliminatePassthroughIterArgs<
153 EliminateWhileLikePassthroughIterArgs<WhileLikeRegionOp>,
154 WhileLikeRegionOp> {
RebuildOpmlir::tfg::__anond15d86580111::EliminateWhileLikePassthroughIterArgs155 static WhileLikeRegionOp RebuildOp(const llvm::BitVector &indices,
156 WhileLikeRegionOp op,
157 IRRewriter &rewriter) {
158 rewriter.setInsertionPoint(op);
159 auto new_op = rewriter.create<WhileLikeRegionOp>(
160 op.getLoc(), FilterByIndex(op.outs().getTypes(), indices),
161 op.ctl().getType(), FilterByIndex(op.init(), indices), op.ctls(),
162 op.parallel_iterationsAttr(), op.cond_attrsAttr(), op.body_attrsAttr(),
163 op.cond_region_attrsAttr(), op.body_region_attrsAttr());
164 new_op.cond_region().takeBody(op.cond_region());
165 new_op.body_region().takeBody(op.body_region());
166 return new_op;
167 }
168
ReplaceArgumentsmlir::tfg::__anond15d86580111::EliminateWhileLikePassthroughIterArgs169 static void ReplaceArguments(unsigned index, WhileLikeRegionOp op, Value data,
170 Value ctl) {
171 // The while loop's condition function only has one result: the condition.
172 // So there are no preserved attributes to delete when removing an iteration
173 // argument.
174 GetLoopRegionDataArgs(op.cond_region())[index].replaceAllUsesWith(data);
175 GetLoopRegionControlTokens(op.cond_region())[index].replaceAllUsesWith(ctl);
176 util::LoopRegionEraseArgument(op.cond_region(), index);
177
178 GetLoopRegionDataArgs(op.body_region())[index].replaceAllUsesWith(data);
179 GetLoopRegionControlTokens(op.body_region())[index].replaceAllUsesWith(ctl);
180 util::LoopRegionEraseArgument(op.body_region(), index);
181 util::LoopRegionResultErased(op.body_region(), index);
182 }
183 };
184
185 struct EliminatePassthroughIterArgsPass
186 : public EliminatePassthroughIterArgsBase<
187 EliminatePassthroughIterArgsPass> {
runOnOperationmlir::tfg::__anond15d86580111::EliminatePassthroughIterArgsPass188 void runOnOperation() override {
189 IRRewriter rewriter(&getContext());
190 getOperation()->walk([&](Operation *op) {
191 if (auto for_op = dyn_cast<ForRegionOp>(op)) {
192 EliminateForPassthroughIterArgs::Run(for_op, rewriter);
193 } else if (auto while_op = dyn_cast<WhileRegionOp>(op)) {
194 EliminateWhileLikePassthroughIterArgs<WhileRegionOp>::Run(while_op,
195 rewriter);
196 } else if (auto while_op = dyn_cast<StatelessWhileRegionOp>(op)) {
197 EliminateWhileLikePassthroughIterArgs<StatelessWhileRegionOp>::Run(
198 while_op, rewriter);
199 } else if (auto while_op = dyn_cast<StatefulWhileRegionOp>(op)) {
200 EliminateWhileLikePassthroughIterArgs<StatefulWhileRegionOp>::Run(
201 while_op, rewriter);
202 }
203 });
204 }
205 };
206 } // namespace
207
CreateEliminatePassthroughIterArgsPass()208 std::unique_ptr<Pass> CreateEliminatePassthroughIterArgsPass() {
209 return std::make_unique<EliminatePassthroughIterArgsPass>();
210 }
211
212 } // namespace tfg
213 } // namespace mlir
214