• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <queue>
18 #include <string>
19 #include <utility>
20 
21 #include "llvm/Support/FormatVariadic.h"
22 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
23 #include "mlir/Pass/Pass.h"  // from @llvm-project
24 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
25 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
26 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
30 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
31 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
32 
33 namespace mlir {
34 namespace TFDevice {
35 
36 namespace {
37 
38 // This pass merges IfRegion ops together if they have the same predicate and it
39 // is safe to do so (there are no intermediate dependencies, they are in the
40 // same block, etc).
41 //
42 // A simple example:
43 //    "tf.IfRegion"(%0) ( {
44 //      %2 = "tf.A"() : () -> (tensor<f32>)
45 //      "tf.Yield"() : () -> ()
46 //      }, {
47 //      "tf.Yield"() : () -> ()
48 //     }) { is_stateless = true } : (tensor<i1>) -> ()
49 //    "tf.IfRegion"(%0) ( {
50 //      %2 = "tf.B"() : () -> (tensor<f32>)
51 //      "tf.Yield"() : () -> ()
52 //      }, {
53 //      "tf.Yield"() : () -> ()
54 //     }) { is_stateless = true } : (tensor<i1>) -> ()
55 // Would become:
56 //    "tf.IfRegion"(%0) ( {
57 //      %2 = "tf.A"() : () -> (tensor<f32>)
58 //      %3 = "tf.B"() : () -> (tensor<f32>)
59 //      "tf.Yield"() : () -> ()
60 //      }, {
61 //      "tf.Yield"() : () -> ()
62 //     }) { is_stateless = true } : (tensor<i1>) -> ()
63 
64 struct MergeControlFlow : public TF::PerFunctionAggregateAnalysisConsumerPass<
65                               MergeControlFlow, TF::SideEffectAnalysis> {
66   void runOnFunction(FuncOp func,
67                      const TF::SideEffectAnalysis::Info& side_effect_analysis);
68 };
69 
70 // Returns whether it is safe to merge `source` IfRegion into `destination`
71 // IfRegion. `source` must come after `destination`.
SafeToMerge(TF::IfRegionOp source,TF::IfRegionOp destination,const TF::SideEffectAnalysis::Info & side_effect_analysis)72 bool SafeToMerge(TF::IfRegionOp source, TF::IfRegionOp destination,
73                  const TF::SideEffectAnalysis::Info& side_effect_analysis) {
74   // IfRegion ops must be in the same block.
75   if (source.getOperation()->getBlock() !=
76       destination.getOperation()->getBlock())
77     return false;
78   assert(destination.getOperation()->isBeforeInBlock(source.getOperation()));
79 
80   llvm::SmallSetVector<Operation*, 4> source_ops;
81   source_ops.insert(source);
82   for (Operation& op : source.then_branch().front()) {
83     source_ops.insert(&op);
84   }
85   for (Operation& op : source.else_branch().front()) {
86     source_ops.insert(&op);
87   }
88 
89   // If there is an intermediate data or side effect dependency between the
90   // ops in destination and the ops in the source, it's not safe to merge
91   // them.
92   std::vector<Operation*> dependencies;
93   for (auto* user : destination.getOperation()->getUsers()) {
94     if (!source_ops.contains(user)) dependencies.push_back(user);
95   }
96   for (Operation& op : destination.then_branch().front()) {
97     for (auto* successor : side_effect_analysis.DirectControlSuccessors(&op)) {
98       if (!source_ops.contains(successor)) dependencies.push_back(successor);
99     }
100   }
101   for (Operation& op : destination.else_branch().front()) {
102     for (auto* successor : side_effect_analysis.DirectControlSuccessors(&op)) {
103       if (!source_ops.contains(successor)) dependencies.push_back(successor);
104     }
105   }
106 
107   bool safe_to_merge = true;
108 
109   llvm::SmallPtrSet<Operation*, 4> visited;
110   while (!dependencies.empty()) {
111     Operation* dependency = dependencies.back();
112     dependencies.pop_back();
113     if (visited.count(dependency)) continue;
114     visited.insert(dependency);
115     for (auto* user : dependency->getUsers()) {
116       if (source_ops.contains(user)) {
117         safe_to_merge = false;
118         break;
119       } else {
120         dependencies.push_back(user);
121       }
122     }
123     for (auto* successor :
124          side_effect_analysis.DirectControlSuccessors(dependency)) {
125       if (source_ops.contains(successor)) {
126         safe_to_merge = false;
127         break;
128       } else {
129         dependencies.push_back(successor);
130       }
131     }
132     // If the op is nested, then also consider the users and successors of the
133     // parent op.
134     if (dependency->getBlock() != destination.getOperation()->getBlock())
135       dependencies.push_back(dependency->getParentOp());
136     if (!safe_to_merge) break;
137   }
138   return safe_to_merge;
139 }
140 
141 // Checks whether a return indice should be kep for `first_if_op` by checking
142 // for results in `second_if_op`.
GetReturnIndicesToKeep(TF::IfRegionOp first_if_op,TF::IfRegionOp second_if_op)143 llvm::SmallVector<int, 4> GetReturnIndicesToKeep(TF::IfRegionOp first_if_op,
144                                                  TF::IfRegionOp second_if_op) {
145   llvm::SmallVector<int, 4> return_indices_to_keep;
146   for (auto& index_and_value : llvm::enumerate(first_if_op.getResults())) {
147     if (!llvm::all_of(index_and_value.value().getUsers(), [&](Operation* op) {
148           return second_if_op->isProperAncestor(op);
149         })) {
150       return_indices_to_keep.push_back(index_and_value.index());
151     }
152   }
153   return return_indices_to_keep;
154 }
155 
156 // Move the body excluding the terminators of else and then regions from
157 // 'source' to 'destination'.
MoveBranches(TF::IfRegionOp source,TF::IfRegionOp destination)158 void MoveBranches(TF::IfRegionOp source, TF::IfRegionOp destination) {
159   Block& destination_then_block = destination.then_branch().front();
160   auto& source_then_body = source.then_branch().front().getOperations();
161   destination_then_block.getOperations().splice(
162       destination_then_block.without_terminator().end(), source_then_body,
163       source_then_body.begin(), std::prev(source_then_body.end()));
164 
165   Block& destination_else_block = destination.else_branch().front();
166   auto& source_else_body = source.else_branch().front().getOperations();
167   destination_else_block.getOperations().splice(
168       destination_else_block.without_terminator().end(), source_else_body,
169       source_else_body.begin(), std::prev(source_else_body.end()));
170 }
171 
172 // Move all ops that depends on the results from `result_op` after `after_op`.
MoveResultsAfter(Operation * result_op,Operation * after_op,const TF::SideEffectAnalysis::Info & side_effect_analysis)173 void MoveResultsAfter(
174     Operation* result_op, Operation* after_op,
175     const TF::SideEffectAnalysis::Info& side_effect_analysis) {
176   std::queue<Operation*> queue;
177 
178   auto enqueue_deps = [&](Operation* source_op) {
179     for (Operation* user : source_op->getUsers()) {
180       queue.push(user);
181     }
182     source_op->walk([&](Operation* walked_op) {
183       for (Operation* successor :
184            side_effect_analysis.DirectControlSuccessors(walked_op)) {
185         if (!source_op->isProperAncestor(successor)) queue.push(successor);
186       }
187     });
188   };
189   enqueue_deps(result_op);
190 
191   while (!queue.empty()) {
192     auto* op = queue.front();
193     queue.pop();
194     while (op->getBlock() != after_op->getBlock()) op = op->getParentOp();
195     if (op->isBeforeInBlock(after_op)) {
196       op->moveAfter(after_op);
197       after_op = op;
198       enqueue_deps(op);
199     }
200   }
201 }
202 
CreateMergedIf(ArrayRef<int> source_return_indices_to_keep,ArrayRef<int> destination_return_indices_to_keep,TF::IfRegionOp source,TF::IfRegionOp destination,const TF::SideEffectAnalysis::Info & side_effect_analysis)203 TF::IfRegionOp CreateMergedIf(
204     ArrayRef<int> source_return_indices_to_keep,
205     ArrayRef<int> destination_return_indices_to_keep, TF::IfRegionOp source,
206     TF::IfRegionOp destination,
207     const TF::SideEffectAnalysis::Info& side_effect_analysis) {
208   llvm::SmallVector<Type, 4> merged_return_types;
209   for (int i : destination_return_indices_to_keep)
210     merged_return_types.push_back(destination.getResult(i).getType());
211   for (int i : source_return_indices_to_keep)
212     merged_return_types.push_back(source.getResult(i).getType());
213 
214   OpBuilder builder(destination);
215   // Create new IfRegion with correct merged results.
216   builder.setInsertionPoint(source.getOperation());
217 
218   auto new_if_op = builder.create<TF::IfRegionOp>(
219       destination.getLoc(), merged_return_types, destination.cond(),
220       destination.is_stateless() && source.is_stateless(),
221       destination._then_func_nameAttr(), destination._else_func_nameAttr());
222   new_if_op.then_branch().push_back(new Block);
223   new_if_op.else_branch().push_back(new Block);
224   // Replace internal usages of merged if ops.
225   for (OpResult result : destination.getResults()) {
226     replaceAllUsesInRegionWith(
227         result,
228         destination.then_branch().front().getTerminator()->getOperand(
229             result.getResultNumber()),
230         source.then_branch());
231     replaceAllUsesInRegionWith(
232         result,
233         destination.else_branch().front().getTerminator()->getOperand(
234             result.getResultNumber()),
235         source.else_branch());
236   }
237 
238   MoveResultsAfter(destination.getOperation(), new_if_op.getOperation(),
239                    side_effect_analysis);
240 
241   // Replace external usages of merged if ops.
242   int new_return_index = 0;
243   for (int i : destination_return_indices_to_keep) {
244     destination.getResult(i).replaceAllUsesWith(
245         new_if_op.getResult(new_return_index++));
246   }
247   for (int i : source_return_indices_to_keep) {
248     source.getResult(i).replaceAllUsesWith(
249         new_if_op.getResult(new_return_index++));
250   }
251 
252   // Create the Yield ops for both branches with merged results.
253   llvm::SmallVector<Value, 4> merged_then_yield_values;
254   for (int i : destination_return_indices_to_keep)
255     merged_then_yield_values.push_back(
256         destination.then_branch().front().getTerminator()->getOperand(i));
257   for (int i : source_return_indices_to_keep)
258     merged_then_yield_values.push_back(
259         source.then_branch().front().getTerminator()->getOperand(i));
260   builder.setInsertionPointToEnd(&new_if_op.then_branch().front());
261   builder.create<TF::YieldOp>(
262       destination.then_branch().front().getTerminator()->getLoc(),
263       /*operands=*/merged_then_yield_values);
264 
265   llvm::SmallVector<Value, 4> merged_else_yield_values;
266   for (int i : destination_return_indices_to_keep)
267     merged_else_yield_values.push_back(
268         destination.else_branch().front().getTerminator()->getOperand(i));
269   for (int i : source_return_indices_to_keep)
270     merged_else_yield_values.push_back(
271         source.else_branch().front().getTerminator()->getOperand(i));
272   builder.setInsertionPointToEnd(&new_if_op.else_branch().front());
273   builder.create<TF::YieldOp>(
274       destination.else_branch().front().getTerminator()->getLoc(),
275       /*operands=*/merged_else_yield_values);
276 
277   // Merge the two branch regions from both IfRegionOps into new IfRegionOp.
278   MoveBranches(/*source=*/destination, /*destination=*/new_if_op);
279   destination.erase();
280   MoveBranches(/*source=*/source, /*destination=*/new_if_op);
281   source.erase();
282   return new_if_op;
283 }
284 
285 // Groups if regions by common predicate and attemps to merge them.
OptimizeIfRegions(Block * block,const TF::SideEffectAnalysis::Info & side_effect_analysis)286 void OptimizeIfRegions(
287     Block* block, const TF::SideEffectAnalysis::Info& side_effect_analysis) {
288   // Determine IfRegions with the same predicate.
289   llvm::SmallDenseMap<Value, llvm::SmallVector<TF::IfRegionOp, 8>, 8>
290       grouped_if_ops;
291   block->walk([&](TF::IfRegionOp if_op) {
292     auto it = grouped_if_ops.try_emplace(if_op.cond());
293     it.first->getSecond().push_back(if_op);
294   });
295 
296   for (auto& entry : grouped_if_ops) {
297     auto& if_ops = entry.second;
298     for (auto it = if_ops.begin(); it != if_ops.end(); ++it) {
299       TF::IfRegionOp first_if_op = *it;
300       for (auto it2 = std::next(it); it2 != if_ops.end(); ++it2) {
301         TF::IfRegionOp second_if_op = *it2;
302         if (!SafeToMerge(second_if_op, first_if_op, side_effect_analysis))
303           break;
304 
305         // For both check if there are uses outside of IfRegion, keep these as
306         // part of the return and replace the internal uses.
307         auto first_return_indices_to_keep =
308             GetReturnIndicesToKeep(first_if_op, second_if_op);
309         auto second_return_indices_to_keep =
310             GetReturnIndicesToKeep(second_if_op, first_if_op);
311 
312         auto new_if_op = CreateMergedIf(
313             second_return_indices_to_keep, first_return_indices_to_keep,
314             second_if_op, first_if_op, side_effect_analysis);
315 
316         if_ops.erase(it2--);
317         first_if_op = new_if_op;
318       }
319     }
320   }
321 }
322 
runOnFunction(FuncOp func,const TF::SideEffectAnalysis::Info & side_effect_analysis)323 void MergeControlFlow::runOnFunction(
324     FuncOp func, const TF::SideEffectAnalysis::Info& side_effect_analysis) {
325   auto result = func.walk([&](tf_device::ClusterOp cluster) {
326     OptimizeIfRegions(&cluster.GetBody(), side_effect_analysis);
327     return WalkResult::advance();
328   });
329 
330   if (result.wasInterrupted()) return signalPassFailure();
331 }
332 
333 }  // namespace
334 
CreateMergeControlFlowPass()335 std::unique_ptr<OperationPass<ModuleOp>> CreateMergeControlFlowPass() {
336   return std::make_unique<MergeControlFlow>();
337 }
338 
339 static PassRegistration<MergeControlFlow> pass(
340     "tf-merge-control-flow", "Merges control flow with a common predicate.");
341 }  // namespace TFDevice
342 }  // namespace mlir
343