1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <queue>
18 #include <string>
19 #include <utility>
20
21 #include "llvm/Support/FormatVariadic.h"
22 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
23 #include "mlir/Pass/Pass.h" // from @llvm-project
24 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
25 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
26 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
30 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
31 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
32 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
33
34 namespace mlir {
35 namespace TFDevice {
36
37 namespace {
38
39 struct MergeControlFlowPass
40 : public TF::MergeControlFlowPassBase<MergeControlFlowPass> {
41 void runOnOperation() override;
42 };
43
44 // Gets the IfRegion op and all of ops in the then and else branches.
GetAllOpsFromIf(TF::IfRegionOp if_op)45 llvm::SmallSetVector<Operation*, 4> GetAllOpsFromIf(TF::IfRegionOp if_op) {
46 llvm::SmallSetVector<Operation*, 4> all_ops;
47 all_ops.insert(if_op);
48 for (Operation& op : if_op.then_branch().front()) {
49 all_ops.insert(&op);
50 }
51 for (Operation& op : if_op.else_branch().front()) {
52 all_ops.insert(&op);
53 }
54 return all_ops;
55 }
56
57 // Returns whether it is safe to merge `source` IfRegion into `destination`
58 // IfRegion. `source` must come after `destination`.
SafeToMerge(TF::IfRegionOp source,TF::IfRegionOp destination,const TF::SideEffectAnalysis::Info & side_effect_analysis)59 bool SafeToMerge(TF::IfRegionOp source, TF::IfRegionOp destination,
60 const TF::SideEffectAnalysis::Info& side_effect_analysis) {
61 // IfRegion ops must be in the same block.
62 if (source.getOperation()->getBlock() !=
63 destination.getOperation()->getBlock())
64 return false;
65 assert(destination.getOperation()->isBeforeInBlock(source.getOperation()));
66
67 llvm::SmallSetVector<Operation*, 4> source_ops = GetAllOpsFromIf(source);
68 llvm::SmallSetVector<Operation*, 4> destination_ops =
69 GetAllOpsFromIf(destination);
70
71 // If there is an intermediate data or side effect dependency between the
72 // ops in destination and the ops in the source, it's not safe to merge
73 // them.
74 std::vector<Operation*> dependencies;
75 for (auto* user : destination.getOperation()->getUsers()) {
76 if (!source_ops.contains(user)) dependencies.push_back(user);
77 }
78 for (auto* successor : side_effect_analysis.DirectControlSuccessors(
79 destination.getOperation())) {
80 if (!source_ops.contains(successor)) dependencies.push_back(successor);
81 }
82 for (Operation& op : destination.then_branch().front()) {
83 for (auto* successor : side_effect_analysis.DirectControlSuccessors(&op)) {
84 if (!source_ops.contains(successor) &&
85 !destination_ops.contains(successor))
86 dependencies.push_back(successor);
87 }
88 }
89 for (Operation& op : destination.else_branch().front()) {
90 for (auto* successor : side_effect_analysis.DirectControlSuccessors(&op)) {
91 if (!source_ops.contains(successor) &&
92 !destination_ops.contains(successor))
93 dependencies.push_back(successor);
94 }
95 }
96
97 bool safe_to_merge = true;
98
99 llvm::SmallPtrSet<Operation*, 4> visited;
100 while (!dependencies.empty()) {
101 Operation* dependency = dependencies.back();
102 dependencies.pop_back();
103 if (visited.count(dependency)) continue;
104 visited.insert(dependency);
105 for (auto* user : dependency->getUsers()) {
106 if (source_ops.contains(user)) {
107 safe_to_merge = false;
108 break;
109 } else {
110 dependencies.push_back(user);
111 }
112 }
113 for (auto* successor :
114 side_effect_analysis.DirectControlSuccessors(dependency)) {
115 if (source_ops.contains(successor)) {
116 safe_to_merge = false;
117 break;
118 } else {
119 dependencies.push_back(successor);
120 }
121 }
122 // If the op is nested, then also consider the users and successors of the
123 // parent op.
124 if (dependency->getBlock() != destination.getOperation()->getBlock())
125 dependencies.push_back(dependency->getParentOp());
126 if (!safe_to_merge) break;
127 }
128 return safe_to_merge;
129 }
130
131 // Checks whether a return indice should be kep for `first_if_op` by checking
132 // for results in `second_if_op`.
GetReturnIndicesToKeep(TF::IfRegionOp first_if_op,TF::IfRegionOp second_if_op)133 llvm::SmallVector<int, 4> GetReturnIndicesToKeep(TF::IfRegionOp first_if_op,
134 TF::IfRegionOp second_if_op) {
135 llvm::SmallVector<int, 4> return_indices_to_keep;
136 for (auto& index_and_value : llvm::enumerate(first_if_op.getResults())) {
137 if (!llvm::all_of(index_and_value.value().getUsers(), [&](Operation* op) {
138 return second_if_op->isProperAncestor(op);
139 })) {
140 return_indices_to_keep.push_back(index_and_value.index());
141 }
142 }
143 return return_indices_to_keep;
144 }
145
146 // Move the body excluding the terminators of else and then regions from
147 // 'source' to 'destination'.
MoveBranches(TF::IfRegionOp source,TF::IfRegionOp destination)148 void MoveBranches(TF::IfRegionOp source, TF::IfRegionOp destination) {
149 Block& destination_then_block = destination.then_branch().front();
150 auto& source_then_body = source.then_branch().front().getOperations();
151 destination_then_block.getOperations().splice(
152 destination_then_block.without_terminator().end(), source_then_body,
153 source_then_body.begin(), std::prev(source_then_body.end()));
154
155 Block& destination_else_block = destination.else_branch().front();
156 auto& source_else_body = source.else_branch().front().getOperations();
157 destination_else_block.getOperations().splice(
158 destination_else_block.without_terminator().end(), source_else_body,
159 source_else_body.begin(), std::prev(source_else_body.end()));
160 }
161
162 // Move all ops that depends on the results from `result_op` after `after_op`.
MoveResultsAfter(Operation * result_op,Operation * after_op,const TF::SideEffectAnalysis::Info & side_effect_analysis)163 void MoveResultsAfter(
164 Operation* result_op, Operation* after_op,
165 const TF::SideEffectAnalysis::Info& side_effect_analysis) {
166 std::queue<Operation*> queue;
167
168 auto enqueue_deps = [&](Operation* source_op) {
169 for (Operation* user : source_op->getUsers()) {
170 queue.push(user);
171 }
172 source_op->walk([&](Operation* walked_op) {
173 for (Operation* successor :
174 side_effect_analysis.DirectControlSuccessors(walked_op)) {
175 if (!source_op->isProperAncestor(successor)) queue.push(successor);
176 }
177 });
178 };
179 enqueue_deps(result_op);
180
181 while (!queue.empty()) {
182 auto* op = queue.front();
183 queue.pop();
184 while (op->getBlock() != after_op->getBlock()) op = op->getParentOp();
185 if (op->isBeforeInBlock(after_op)) {
186 op->moveAfter(after_op);
187 after_op = op;
188 enqueue_deps(op);
189 }
190 }
191 }
192
CreateMergedIf(ArrayRef<int> source_return_indices_to_keep,ArrayRef<int> destination_return_indices_to_keep,TF::IfRegionOp source,TF::IfRegionOp destination,const TF::SideEffectAnalysis::Info & side_effect_analysis)193 TF::IfRegionOp CreateMergedIf(
194 ArrayRef<int> source_return_indices_to_keep,
195 ArrayRef<int> destination_return_indices_to_keep, TF::IfRegionOp source,
196 TF::IfRegionOp destination,
197 const TF::SideEffectAnalysis::Info& side_effect_analysis) {
198 llvm::SmallVector<Type, 4> merged_return_types;
199 for (int i : destination_return_indices_to_keep)
200 merged_return_types.push_back(destination.getResult(i).getType());
201 for (int i : source_return_indices_to_keep)
202 merged_return_types.push_back(source.getResult(i).getType());
203
204 OpBuilder builder(destination);
205 // Create new IfRegion with correct merged results.
206 builder.setInsertionPoint(source.getOperation());
207
208 auto new_if_op = builder.create<TF::IfRegionOp>(
209 destination.getLoc(), merged_return_types, destination.cond(),
210 destination.is_stateless() && source.is_stateless(),
211 destination._then_func_nameAttr(), destination._else_func_nameAttr());
212 new_if_op.then_branch().push_back(new Block);
213 new_if_op.else_branch().push_back(new Block);
214 // Replace internal usages of merged if ops.
215 for (OpResult result : destination.getResults()) {
216 replaceAllUsesInRegionWith(
217 result,
218 destination.then_branch().front().getTerminator()->getOperand(
219 result.getResultNumber()),
220 source.then_branch());
221 replaceAllUsesInRegionWith(
222 result,
223 destination.else_branch().front().getTerminator()->getOperand(
224 result.getResultNumber()),
225 source.else_branch());
226 }
227
228 MoveResultsAfter(destination.getOperation(), new_if_op.getOperation(),
229 side_effect_analysis);
230
231 // Replace external usages of merged if ops.
232 int new_return_index = 0;
233 for (int i : destination_return_indices_to_keep) {
234 destination.getResult(i).replaceAllUsesWith(
235 new_if_op.getResult(new_return_index++));
236 }
237 for (int i : source_return_indices_to_keep) {
238 source.getResult(i).replaceAllUsesWith(
239 new_if_op.getResult(new_return_index++));
240 }
241
242 // Create the Yield ops for both branches with merged results.
243 llvm::SmallVector<Value, 4> merged_then_yield_values;
244 for (int i : destination_return_indices_to_keep)
245 merged_then_yield_values.push_back(
246 destination.then_branch().front().getTerminator()->getOperand(i));
247 for (int i : source_return_indices_to_keep)
248 merged_then_yield_values.push_back(
249 source.then_branch().front().getTerminator()->getOperand(i));
250 builder.setInsertionPointToEnd(&new_if_op.then_branch().front());
251 builder.create<TF::YieldOp>(
252 destination.then_branch().front().getTerminator()->getLoc(),
253 /*operands=*/merged_then_yield_values);
254
255 llvm::SmallVector<Value, 4> merged_else_yield_values;
256 for (int i : destination_return_indices_to_keep)
257 merged_else_yield_values.push_back(
258 destination.else_branch().front().getTerminator()->getOperand(i));
259 for (int i : source_return_indices_to_keep)
260 merged_else_yield_values.push_back(
261 source.else_branch().front().getTerminator()->getOperand(i));
262 builder.setInsertionPointToEnd(&new_if_op.else_branch().front());
263 builder.create<TF::YieldOp>(
264 destination.else_branch().front().getTerminator()->getLoc(),
265 /*operands=*/merged_else_yield_values);
266
267 // Merge the two branch regions from both IfRegionOps into new IfRegionOp.
268 MoveBranches(/*source=*/destination, /*destination=*/new_if_op);
269 destination.erase();
270 MoveBranches(/*source=*/source, /*destination=*/new_if_op);
271 source.erase();
272 return new_if_op;
273 }
274
275 // Groups if regions by common predicate and attemps to merge them.
OptimizeIfRegions(Block * block,ModuleOp module)276 void OptimizeIfRegions(Block* block, ModuleOp module) {
277 // Determine IfRegions with the same predicate.
278 llvm::SmallDenseMap<Value, llvm::SmallVector<TF::IfRegionOp, 8>, 8>
279 grouped_if_ops;
280 block->walk([&](TF::IfRegionOp if_op) {
281 auto it = grouped_if_ops.try_emplace(if_op.cond());
282 it.first->getSecond().push_back(if_op);
283 });
284
285 auto side_effect_analysis = std::make_unique<TF::SideEffectAnalysis>(module);
286
287 for (auto& entry : grouped_if_ops) {
288 auto& if_ops = entry.second;
289 for (auto it = if_ops.begin(); it != if_ops.end(); ++it) {
290 TF::IfRegionOp first_if_op = *it;
291 for (auto it2 = std::next(it); it2 != if_ops.end(); ++it2) {
292 FuncOp func = first_if_op->getParentOfType<FuncOp>();
293 const TF::SideEffectAnalysis::Info& analysis =
294 side_effect_analysis->GetAnalysisForFunc(func);
295
296 TF::IfRegionOp second_if_op = *it2;
297 if (!SafeToMerge(second_if_op, first_if_op, analysis)) break;
298
299 // For both check if there are uses outside of IfRegion, keep these as
300 // part of the return and replace the internal uses.
301 auto first_return_indices_to_keep =
302 GetReturnIndicesToKeep(first_if_op, second_if_op);
303 auto second_return_indices_to_keep =
304 GetReturnIndicesToKeep(second_if_op, first_if_op);
305
306 auto new_if_op = CreateMergedIf(second_return_indices_to_keep,
307 first_return_indices_to_keep,
308 second_if_op, first_if_op, analysis);
309
310 if_ops.erase(it2--);
311 first_if_op = new_if_op;
312 // We regenerate the side effect analysis since merging the IfRegions
313 // invalidates the side effect analysis. This approach is O(N*M) where
314 // N is the number of ops in `module` and M is the number of pairs of
315 // IfRegion ops that are merged.
316 side_effect_analysis = std::make_unique<TF::SideEffectAnalysis>(module);
317 }
318 }
319 }
320 }
321
runOnOperation()322 void MergeControlFlowPass::runOnOperation() {
323 ModuleOp module = getOperation();
324 auto result = module.walk([&](tf_device::ClusterOp cluster) {
325 OptimizeIfRegions(&cluster.GetBody(), module);
326 return WalkResult::advance();
327 });
328
329 if (result.wasInterrupted()) return signalPassFailure();
330 }
331
332 } // namespace
333
CreateMergeControlFlowPass()334 std::unique_ptr<OperationPass<ModuleOp>> CreateMergeControlFlowPass() {
335 return std::make_unique<MergeControlFlowPass>();
336 }
337
338 } // namespace TFDevice
339 } // namespace mlir
340