1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <queue>
18 #include <string>
19 #include <utility>
20
21 #include "llvm/Support/FormatVariadic.h"
22 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
23 #include "mlir/Pass/Pass.h" // from @llvm-project
24 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
25 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
26 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
30 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
31 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
32
33 namespace mlir {
34 namespace TFDevice {
35
36 namespace {
37
38 // This pass merges IfRegion ops together if they have the same predicate and it
39 // is safe to do so (there are no intermediate dependencies, they are in the
40 // same block, etc).
41 //
42 // A simple example:
43 // "tf.IfRegion"(%0) ( {
44 // %2 = "tf.A"() : () -> (tensor<f32>)
45 // "tf.Yield"() : () -> ()
46 // }, {
47 // "tf.Yield"() : () -> ()
48 // }) { is_stateless = true } : (tensor<i1>) -> ()
49 // "tf.IfRegion"(%0) ( {
50 // %2 = "tf.B"() : () -> (tensor<f32>)
51 // "tf.Yield"() : () -> ()
52 // }, {
53 // "tf.Yield"() : () -> ()
54 // }) { is_stateless = true } : (tensor<i1>) -> ()
55 // Would become:
56 // "tf.IfRegion"(%0) ( {
57 // %2 = "tf.A"() : () -> (tensor<f32>)
58 // %3 = "tf.B"() : () -> (tensor<f32>)
59 // "tf.Yield"() : () -> ()
60 // }, {
61 // "tf.Yield"() : () -> ()
62 // }) { is_stateless = true } : (tensor<i1>) -> ()
63
64 struct MergeControlFlow : public TF::PerFunctionAggregateAnalysisConsumerPass<
65 MergeControlFlow, TF::SideEffectAnalysis> {
66 void runOnFunction(FuncOp func,
67 const TF::SideEffectAnalysis::Info& side_effect_analysis);
68 };
69
70 // Returns whether it is safe to merge `source` IfRegion into `destination`
71 // IfRegion. `source` must come after `destination`.
SafeToMerge(TF::IfRegionOp source,TF::IfRegionOp destination,const TF::SideEffectAnalysis::Info & side_effect_analysis)72 bool SafeToMerge(TF::IfRegionOp source, TF::IfRegionOp destination,
73 const TF::SideEffectAnalysis::Info& side_effect_analysis) {
74 // IfRegion ops must be in the same block.
75 if (source.getOperation()->getBlock() !=
76 destination.getOperation()->getBlock())
77 return false;
78 assert(destination.getOperation()->isBeforeInBlock(source.getOperation()));
79
80 llvm::SmallSetVector<Operation*, 4> source_ops;
81 source_ops.insert(source);
82 for (Operation& op : source.then_branch().front()) {
83 source_ops.insert(&op);
84 }
85 for (Operation& op : source.else_branch().front()) {
86 source_ops.insert(&op);
87 }
88
89 // If there is an intermediate data or side effect dependency between the
90 // ops in destination and the ops in the source, it's not safe to merge
91 // them.
92 std::vector<Operation*> dependencies;
93 for (auto* user : destination.getOperation()->getUsers()) {
94 if (!source_ops.contains(user)) dependencies.push_back(user);
95 }
96 for (Operation& op : destination.then_branch().front()) {
97 for (auto* successor : side_effect_analysis.DirectControlSuccessors(&op)) {
98 if (!source_ops.contains(successor)) dependencies.push_back(successor);
99 }
100 }
101 for (Operation& op : destination.else_branch().front()) {
102 for (auto* successor : side_effect_analysis.DirectControlSuccessors(&op)) {
103 if (!source_ops.contains(successor)) dependencies.push_back(successor);
104 }
105 }
106
107 bool safe_to_merge = true;
108
109 llvm::SmallPtrSet<Operation*, 4> visited;
110 while (!dependencies.empty()) {
111 Operation* dependency = dependencies.back();
112 dependencies.pop_back();
113 if (visited.count(dependency)) continue;
114 visited.insert(dependency);
115 for (auto* user : dependency->getUsers()) {
116 if (source_ops.contains(user)) {
117 safe_to_merge = false;
118 break;
119 } else {
120 dependencies.push_back(user);
121 }
122 }
123 for (auto* successor :
124 side_effect_analysis.DirectControlSuccessors(dependency)) {
125 if (source_ops.contains(successor)) {
126 safe_to_merge = false;
127 break;
128 } else {
129 dependencies.push_back(successor);
130 }
131 }
132 // If the op is nested, then also consider the users and successors of the
133 // parent op.
134 if (dependency->getBlock() != destination.getOperation()->getBlock())
135 dependencies.push_back(dependency->getParentOp());
136 if (!safe_to_merge) break;
137 }
138 return safe_to_merge;
139 }
140
141 // Checks whether a return indice should be kep for `first_if_op` by checking
142 // for results in `second_if_op`.
GetReturnIndicesToKeep(TF::IfRegionOp first_if_op,TF::IfRegionOp second_if_op)143 llvm::SmallVector<int, 4> GetReturnIndicesToKeep(TF::IfRegionOp first_if_op,
144 TF::IfRegionOp second_if_op) {
145 llvm::SmallVector<int, 4> return_indices_to_keep;
146 for (auto& index_and_value : llvm::enumerate(first_if_op.getResults())) {
147 if (!llvm::all_of(index_and_value.value().getUsers(), [&](Operation* op) {
148 return second_if_op->isProperAncestor(op);
149 })) {
150 return_indices_to_keep.push_back(index_and_value.index());
151 }
152 }
153 return return_indices_to_keep;
154 }
155
156 // Move the body excluding the terminators of else and then regions from
157 // 'source' to 'destination'.
MoveBranches(TF::IfRegionOp source,TF::IfRegionOp destination)158 void MoveBranches(TF::IfRegionOp source, TF::IfRegionOp destination) {
159 Block& destination_then_block = destination.then_branch().front();
160 auto& source_then_body = source.then_branch().front().getOperations();
161 destination_then_block.getOperations().splice(
162 destination_then_block.without_terminator().end(), source_then_body,
163 source_then_body.begin(), std::prev(source_then_body.end()));
164
165 Block& destination_else_block = destination.else_branch().front();
166 auto& source_else_body = source.else_branch().front().getOperations();
167 destination_else_block.getOperations().splice(
168 destination_else_block.without_terminator().end(), source_else_body,
169 source_else_body.begin(), std::prev(source_else_body.end()));
170 }
171
172 // Move all ops that depends on the results from `result_op` after `after_op`.
MoveResultsAfter(Operation * result_op,Operation * after_op,const TF::SideEffectAnalysis::Info & side_effect_analysis)173 void MoveResultsAfter(
174 Operation* result_op, Operation* after_op,
175 const TF::SideEffectAnalysis::Info& side_effect_analysis) {
176 std::queue<Operation*> queue;
177
178 auto enqueue_deps = [&](Operation* source_op) {
179 for (Operation* user : source_op->getUsers()) {
180 queue.push(user);
181 }
182 source_op->walk([&](Operation* walked_op) {
183 for (Operation* successor :
184 side_effect_analysis.DirectControlSuccessors(walked_op)) {
185 if (!source_op->isProperAncestor(successor)) queue.push(successor);
186 }
187 });
188 };
189 enqueue_deps(result_op);
190
191 while (!queue.empty()) {
192 auto* op = queue.front();
193 queue.pop();
194 while (op->getBlock() != after_op->getBlock()) op = op->getParentOp();
195 if (op->isBeforeInBlock(after_op)) {
196 op->moveAfter(after_op);
197 after_op = op;
198 enqueue_deps(op);
199 }
200 }
201 }
202
CreateMergedIf(ArrayRef<int> source_return_indices_to_keep,ArrayRef<int> destination_return_indices_to_keep,TF::IfRegionOp source,TF::IfRegionOp destination,const TF::SideEffectAnalysis::Info & side_effect_analysis)203 TF::IfRegionOp CreateMergedIf(
204 ArrayRef<int> source_return_indices_to_keep,
205 ArrayRef<int> destination_return_indices_to_keep, TF::IfRegionOp source,
206 TF::IfRegionOp destination,
207 const TF::SideEffectAnalysis::Info& side_effect_analysis) {
208 llvm::SmallVector<Type, 4> merged_return_types;
209 for (int i : destination_return_indices_to_keep)
210 merged_return_types.push_back(destination.getResult(i).getType());
211 for (int i : source_return_indices_to_keep)
212 merged_return_types.push_back(source.getResult(i).getType());
213
214 OpBuilder builder(destination);
215 // Create new IfRegion with correct merged results.
216 builder.setInsertionPoint(source.getOperation());
217
218 auto new_if_op = builder.create<TF::IfRegionOp>(
219 destination.getLoc(), merged_return_types, destination.cond(),
220 destination.is_stateless() && source.is_stateless(),
221 destination._then_func_nameAttr(), destination._else_func_nameAttr());
222 new_if_op.then_branch().push_back(new Block);
223 new_if_op.else_branch().push_back(new Block);
224 // Replace internal usages of merged if ops.
225 for (OpResult result : destination.getResults()) {
226 replaceAllUsesInRegionWith(
227 result,
228 destination.then_branch().front().getTerminator()->getOperand(
229 result.getResultNumber()),
230 source.then_branch());
231 replaceAllUsesInRegionWith(
232 result,
233 destination.else_branch().front().getTerminator()->getOperand(
234 result.getResultNumber()),
235 source.else_branch());
236 }
237
238 MoveResultsAfter(destination.getOperation(), new_if_op.getOperation(),
239 side_effect_analysis);
240
241 // Replace external usages of merged if ops.
242 int new_return_index = 0;
243 for (int i : destination_return_indices_to_keep) {
244 destination.getResult(i).replaceAllUsesWith(
245 new_if_op.getResult(new_return_index++));
246 }
247 for (int i : source_return_indices_to_keep) {
248 source.getResult(i).replaceAllUsesWith(
249 new_if_op.getResult(new_return_index++));
250 }
251
252 // Create the Yield ops for both branches with merged results.
253 llvm::SmallVector<Value, 4> merged_then_yield_values;
254 for (int i : destination_return_indices_to_keep)
255 merged_then_yield_values.push_back(
256 destination.then_branch().front().getTerminator()->getOperand(i));
257 for (int i : source_return_indices_to_keep)
258 merged_then_yield_values.push_back(
259 source.then_branch().front().getTerminator()->getOperand(i));
260 builder.setInsertionPointToEnd(&new_if_op.then_branch().front());
261 builder.create<TF::YieldOp>(
262 destination.then_branch().front().getTerminator()->getLoc(),
263 /*operands=*/merged_then_yield_values);
264
265 llvm::SmallVector<Value, 4> merged_else_yield_values;
266 for (int i : destination_return_indices_to_keep)
267 merged_else_yield_values.push_back(
268 destination.else_branch().front().getTerminator()->getOperand(i));
269 for (int i : source_return_indices_to_keep)
270 merged_else_yield_values.push_back(
271 source.else_branch().front().getTerminator()->getOperand(i));
272 builder.setInsertionPointToEnd(&new_if_op.else_branch().front());
273 builder.create<TF::YieldOp>(
274 destination.else_branch().front().getTerminator()->getLoc(),
275 /*operands=*/merged_else_yield_values);
276
277 // Merge the two branch regions from both IfRegionOps into new IfRegionOp.
278 MoveBranches(/*source=*/destination, /*destination=*/new_if_op);
279 destination.erase();
280 MoveBranches(/*source=*/source, /*destination=*/new_if_op);
281 source.erase();
282 return new_if_op;
283 }
284
285 // Groups if regions by common predicate and attemps to merge them.
OptimizeIfRegions(Block * block,const TF::SideEffectAnalysis::Info & side_effect_analysis)286 void OptimizeIfRegions(
287 Block* block, const TF::SideEffectAnalysis::Info& side_effect_analysis) {
288 // Determine IfRegions with the same predicate.
289 llvm::SmallDenseMap<Value, llvm::SmallVector<TF::IfRegionOp, 8>, 8>
290 grouped_if_ops;
291 block->walk([&](TF::IfRegionOp if_op) {
292 auto it = grouped_if_ops.try_emplace(if_op.cond());
293 it.first->getSecond().push_back(if_op);
294 });
295
296 for (auto& entry : grouped_if_ops) {
297 auto& if_ops = entry.second;
298 for (auto it = if_ops.begin(); it != if_ops.end(); ++it) {
299 TF::IfRegionOp first_if_op = *it;
300 for (auto it2 = std::next(it); it2 != if_ops.end(); ++it2) {
301 TF::IfRegionOp second_if_op = *it2;
302 if (!SafeToMerge(second_if_op, first_if_op, side_effect_analysis))
303 break;
304
305 // For both check if there are uses outside of IfRegion, keep these as
306 // part of the return and replace the internal uses.
307 auto first_return_indices_to_keep =
308 GetReturnIndicesToKeep(first_if_op, second_if_op);
309 auto second_return_indices_to_keep =
310 GetReturnIndicesToKeep(second_if_op, first_if_op);
311
312 auto new_if_op = CreateMergedIf(
313 second_return_indices_to_keep, first_return_indices_to_keep,
314 second_if_op, first_if_op, side_effect_analysis);
315
316 if_ops.erase(it2--);
317 first_if_op = new_if_op;
318 }
319 }
320 }
321 }
322
runOnFunction(FuncOp func,const TF::SideEffectAnalysis::Info & side_effect_analysis)323 void MergeControlFlow::runOnFunction(
324 FuncOp func, const TF::SideEffectAnalysis::Info& side_effect_analysis) {
325 auto result = func.walk([&](tf_device::ClusterOp cluster) {
326 OptimizeIfRegions(&cluster.GetBody(), side_effect_analysis);
327 return WalkResult::advance();
328 });
329
330 if (result.wasInterrupted()) return signalPassFailure();
331 }
332
333 } // namespace
334
CreateMergeControlFlowPass()335 std::unique_ptr<OperationPass<ModuleOp>> CreateMergeControlFlowPass() {
336 return std::make_unique<MergeControlFlow>();
337 }
338
339 static PassRegistration<MergeControlFlow> pass(
340 "tf-merge-control-flow", "Merges control flow with a common predicate.");
341 } // namespace TFDevice
342 } // namespace mlir
343