• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/transforms/eliminate_passthrough_iter_args/pass.h"
17 
18 #include <utility>
19 
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/EpochTracker.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/Support/Debug.h"
24 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
25 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
26 #include "mlir/IR/Value.h"  // from @llvm-project
27 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
28 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
29 #include "tensorflow/core/ir/ops.h"
30 #include "tensorflow/core/ir/utility.h"
31 #include "tensorflow/core/transforms/pass_detail.h"
32 #include "tensorflow/core/transforms/utils/utils.h"
33 
// Define the debug label used by `LLVM_DEBUG`. It is named after this pass so
// that `-debug-only=eliminate-passthrough-iter-args` selects exactly the debug
// output emitted from this file.
#define DEBUG_TYPE "eliminate-passthrough-iter-args"
36 
37 namespace mlir {
38 namespace tfg {
39 
40 // Given a range of elements, this function returns a vector of elements
41 // excluding the ones whose index is contained in a bit vector.
42 template <typename RangeT>
FilterByIndex(RangeT && range,const llvm::BitVector & indices)43 static SmallVector<llvm::detail::ValueOfRange<RangeT>> FilterByIndex(
44     RangeT &&range, const llvm::BitVector &indices) {
45   SmallVector<llvm::detail::ValueOfRange<RangeT>> result;
46   for (auto &it : llvm::enumerate(range))
47     if (!indices.test(it.index())) result.push_back(it.value());
48   return result;
49 }
50 
51 // Given a region, return the indices of arguments that are passthrough.
GetPassthroughIndices(Region & region)52 static SmallVector<unsigned> GetPassthroughIndices(Region &region) {
53   Block::BlockArgListType iter_args = GetLoopRegionDataArgs(region);
54 
55   // Skip the index argument for `For` ops.
56   if (isa<ForRegionOp>(region.getParentOp()))
57     iter_args = iter_args.drop_front();
58 
59   SmallVector<unsigned> indices;
60   auto yield = cast<YieldOp>(region.front().getTerminator());
61   for (auto it : llvm::zip(llvm::enumerate(yield.args()), iter_args)) {
62     if (std::get<0>(it).value() == std::get<1>(it))
63       indices.push_back(std::get<0>(it).index());
64   }
65   return indices;
66 }
67 
68 // Given a terminator, erase the iteration arguments at the specified index.
EraseIterArgsFromTerminator(Operation * terminator,ArrayRef<unsigned> indices)69 static void EraseIterArgsFromTerminator(Operation *terminator,
70                                         ArrayRef<unsigned> indices) {
71   if (isa<YieldOp>(terminator)) {
72     util::SizedOperandSegmentsEraseOperands(terminator, indices);
73     return;
74   }
75 
76   // Skip the condition operand.
77   assert(isa<ConditionOp>(terminator));
78   SmallVector<unsigned> cond_indices = llvm::to_vector(indices);
79   for (unsigned &index : cond_indices) ++index;
80   util::SizedOperandSegmentsEraseOperands(terminator, cond_indices);
81 }
82 
83 namespace {
84 template <typename ConcreteT, typename OpT>
85 struct EliminatePassthroughIterArgs {
86   // Eliminate the passthrough iteration arguments for the given op. Returns the
87   // number of eliminated arguments.
Runmlir::tfg::__anond15d86580111::EliminatePassthroughIterArgs88   static size_t Run(OpT op, IRRewriter &rewriter) {
89     SmallVector<unsigned> indices = GetPassthroughIndices(op.body_region());
90     if (indices.empty()) return 0;
91 
92     LLVM_DEBUG(llvm::dbgs()
93                << "Number of captures to erase: " << indices.size() << "\n");
94     // We need to:
95     // 1. remove the terminator operands
96     // 2. remove the operands and corresponding results (and replace them)
97     // 3. remove the block arguments (and update preserved attributes)
98     llvm::BitVector remove(op.init().size());
99     for (unsigned index : indices) remove.set(index);
100 
101     for (Region &region : op->getRegions())
102       EraseIterArgsFromTerminator(region.front().getTerminator(), indices);
103 
104     OpT new_op = ConcreteT::RebuildOp(remove, op, rewriter);
105     util::ForwardNonIntrinsicAttributes(op, new_op);
106 
107     // Replace uses of each passthrough argument with the implicit capture
108     // value and remove the argument. Insert the implicitly captured value into
109     // the result list to replace the removed results from the original op.
110     SmallVector<Value> results = llvm::to_vector(ValueRange(new_op.outs()));
111     for (auto &it : llvm::enumerate(indices)) {
112       unsigned idx = it.value() - it.index();
113       Value data = op.init()[it.value()];
114       results.insert(results.begin() + it.value(), data);
115       ConcreteT::ReplaceArguments(idx, new_op, data,
116                                   LookupControlDependency(data));
117     }
118     results.push_back(new_op.ctl());
119     rewriter.replaceOp(op, results);
120     return indices.size();
121   }
122 };
123 
124 struct EliminateForPassthroughIterArgs
125     : public EliminatePassthroughIterArgs<EliminateForPassthroughIterArgs,
126                                           ForRegionOp> {
RebuildOpmlir::tfg::__anond15d86580111::EliminateForPassthroughIterArgs127   static ForRegionOp RebuildOp(const llvm::BitVector &indices, ForRegionOp op,
128                                IRRewriter &rewriter) {
129     rewriter.setInsertionPoint(op);
130     auto new_op = rewriter.create<ForRegionOp>(
131         op.getLoc(), FilterByIndex(op.outs().getTypes(), indices),
132         op.ctl().getType(), op.start(), op.limit(), op.delta(),
133         FilterByIndex(op.init(), indices), op.ctls(), op.body_attrsAttr(),
134         op.region_attrsAttr());
135     new_op.body_region().takeBody(op.body_region());
136     return new_op;
137   }
138 
ReplaceArgumentsmlir::tfg::__anond15d86580111::EliminateForPassthroughIterArgs139   static void ReplaceArguments(unsigned index, ForRegionOp op, Value data,
140                                Value ctl) {
141     // Argument indexing starts from 1 (skip the loop index argument).
142     GetLoopRegionDataArgs(op.body_region())[index + 1].replaceAllUsesWith(data);
143     GetLoopRegionControlTokens(op.body_region())[index + 1].replaceAllUsesWith(
144         ctl);
145     util::LoopRegionEraseArgument(op.body_region(), index + 1);
146     util::LoopRegionResultErased(op.body_region(), index);
147   }
148 };
149 
150 template <typename WhileLikeRegionOp>
151 struct EliminateWhileLikePassthroughIterArgs
152     : public EliminatePassthroughIterArgs<
153           EliminateWhileLikePassthroughIterArgs<WhileLikeRegionOp>,
154           WhileLikeRegionOp> {
RebuildOpmlir::tfg::__anond15d86580111::EliminateWhileLikePassthroughIterArgs155   static WhileLikeRegionOp RebuildOp(const llvm::BitVector &indices,
156                                      WhileLikeRegionOp op,
157                                      IRRewriter &rewriter) {
158     rewriter.setInsertionPoint(op);
159     auto new_op = rewriter.create<WhileLikeRegionOp>(
160         op.getLoc(), FilterByIndex(op.outs().getTypes(), indices),
161         op.ctl().getType(), FilterByIndex(op.init(), indices), op.ctls(),
162         op.parallel_iterationsAttr(), op.cond_attrsAttr(), op.body_attrsAttr(),
163         op.cond_region_attrsAttr(), op.body_region_attrsAttr());
164     new_op.cond_region().takeBody(op.cond_region());
165     new_op.body_region().takeBody(op.body_region());
166     return new_op;
167   }
168 
ReplaceArgumentsmlir::tfg::__anond15d86580111::EliminateWhileLikePassthroughIterArgs169   static void ReplaceArguments(unsigned index, WhileLikeRegionOp op, Value data,
170                                Value ctl) {
171     // The while loop's condition function only has one result: the condition.
172     // So there are no preserved attributes to delete when removing an iteration
173     // argument.
174     GetLoopRegionDataArgs(op.cond_region())[index].replaceAllUsesWith(data);
175     GetLoopRegionControlTokens(op.cond_region())[index].replaceAllUsesWith(ctl);
176     util::LoopRegionEraseArgument(op.cond_region(), index);
177 
178     GetLoopRegionDataArgs(op.body_region())[index].replaceAllUsesWith(data);
179     GetLoopRegionControlTokens(op.body_region())[index].replaceAllUsesWith(ctl);
180     util::LoopRegionEraseArgument(op.body_region(), index);
181     util::LoopRegionResultErased(op.body_region(), index);
182   }
183 };
184 
185 struct EliminatePassthroughIterArgsPass
186     : public EliminatePassthroughIterArgsBase<
187           EliminatePassthroughIterArgsPass> {
runOnOperationmlir::tfg::__anond15d86580111::EliminatePassthroughIterArgsPass188   void runOnOperation() override {
189     IRRewriter rewriter(&getContext());
190     getOperation()->walk([&](Operation *op) {
191       if (auto for_op = dyn_cast<ForRegionOp>(op)) {
192         EliminateForPassthroughIterArgs::Run(for_op, rewriter);
193       } else if (auto while_op = dyn_cast<WhileRegionOp>(op)) {
194         EliminateWhileLikePassthroughIterArgs<WhileRegionOp>::Run(while_op,
195                                                                   rewriter);
196       } else if (auto while_op = dyn_cast<StatelessWhileRegionOp>(op)) {
197         EliminateWhileLikePassthroughIterArgs<StatelessWhileRegionOp>::Run(
198             while_op, rewriter);
199       } else if (auto while_op = dyn_cast<StatefulWhileRegionOp>(op)) {
200         EliminateWhileLikePassthroughIterArgs<StatefulWhileRegionOp>::Run(
201             while_op, rewriter);
202       }
203     });
204   }
205 };
206 }  // namespace
207 
CreateEliminatePassthroughIterArgsPass()208 std::unique_ptr<Pass> CreateEliminatePassthroughIterArgsPass() {
209   return std::make_unique<EliminatePassthroughIterArgsPass>();
210 }
211 
212 }  // namespace tfg
213 }  // namespace mlir
214