• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <memory>
#include <queue>
#include <string>
#include <utility>

#include "llvm/ADT/MapVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
33 
34 namespace mlir {
35 namespace TFDevice {
36 
37 namespace {
38 
39 struct MergeControlFlowPass
40     : public TF::MergeControlFlowPassBase<MergeControlFlowPass> {
41   void runOnOperation() override;
42 };
43 
44 // Gets the IfRegion op and all of ops in the then and else branches.
GetAllOpsFromIf(TF::IfRegionOp if_op)45 llvm::SmallSetVector<Operation*, 4> GetAllOpsFromIf(TF::IfRegionOp if_op) {
46   llvm::SmallSetVector<Operation*, 4> all_ops;
47   all_ops.insert(if_op);
48   for (Operation& op : if_op.then_branch().front()) {
49     all_ops.insert(&op);
50   }
51   for (Operation& op : if_op.else_branch().front()) {
52     all_ops.insert(&op);
53   }
54   return all_ops;
55 }
56 
57 // Returns whether it is safe to merge `source` IfRegion into `destination`
58 // IfRegion. `source` must come after `destination`.
SafeToMerge(TF::IfRegionOp source,TF::IfRegionOp destination,const TF::SideEffectAnalysis::Info & side_effect_analysis)59 bool SafeToMerge(TF::IfRegionOp source, TF::IfRegionOp destination,
60                  const TF::SideEffectAnalysis::Info& side_effect_analysis) {
61   // IfRegion ops must be in the same block.
62   if (source.getOperation()->getBlock() !=
63       destination.getOperation()->getBlock())
64     return false;
65   assert(destination.getOperation()->isBeforeInBlock(source.getOperation()));
66 
67   llvm::SmallSetVector<Operation*, 4> source_ops = GetAllOpsFromIf(source);
68   llvm::SmallSetVector<Operation*, 4> destination_ops =
69       GetAllOpsFromIf(destination);
70 
71   // If there is an intermediate data or side effect dependency between the
72   // ops in destination and the ops in the source, it's not safe to merge
73   // them.
74   std::vector<Operation*> dependencies;
75   for (auto* user : destination.getOperation()->getUsers()) {
76     if (!source_ops.contains(user)) dependencies.push_back(user);
77   }
78   for (auto* successor : side_effect_analysis.DirectControlSuccessors(
79            destination.getOperation())) {
80     if (!source_ops.contains(successor)) dependencies.push_back(successor);
81   }
82   for (Operation& op : destination.then_branch().front()) {
83     for (auto* successor : side_effect_analysis.DirectControlSuccessors(&op)) {
84       if (!source_ops.contains(successor) &&
85           !destination_ops.contains(successor))
86         dependencies.push_back(successor);
87     }
88   }
89   for (Operation& op : destination.else_branch().front()) {
90     for (auto* successor : side_effect_analysis.DirectControlSuccessors(&op)) {
91       if (!source_ops.contains(successor) &&
92           !destination_ops.contains(successor))
93         dependencies.push_back(successor);
94     }
95   }
96 
97   bool safe_to_merge = true;
98 
99   llvm::SmallPtrSet<Operation*, 4> visited;
100   while (!dependencies.empty()) {
101     Operation* dependency = dependencies.back();
102     dependencies.pop_back();
103     if (visited.count(dependency)) continue;
104     visited.insert(dependency);
105     for (auto* user : dependency->getUsers()) {
106       if (source_ops.contains(user)) {
107         safe_to_merge = false;
108         break;
109       } else {
110         dependencies.push_back(user);
111       }
112     }
113     for (auto* successor :
114          side_effect_analysis.DirectControlSuccessors(dependency)) {
115       if (source_ops.contains(successor)) {
116         safe_to_merge = false;
117         break;
118       } else {
119         dependencies.push_back(successor);
120       }
121     }
122     // If the op is nested, then also consider the users and successors of the
123     // parent op.
124     if (dependency->getBlock() != destination.getOperation()->getBlock())
125       dependencies.push_back(dependency->getParentOp());
126     if (!safe_to_merge) break;
127   }
128   return safe_to_merge;
129 }
130 
131 // Checks whether a return indice should be kep for `first_if_op` by checking
132 // for results in `second_if_op`.
GetReturnIndicesToKeep(TF::IfRegionOp first_if_op,TF::IfRegionOp second_if_op)133 llvm::SmallVector<int, 4> GetReturnIndicesToKeep(TF::IfRegionOp first_if_op,
134                                                  TF::IfRegionOp second_if_op) {
135   llvm::SmallVector<int, 4> return_indices_to_keep;
136   for (auto& index_and_value : llvm::enumerate(first_if_op.getResults())) {
137     if (!llvm::all_of(index_and_value.value().getUsers(), [&](Operation* op) {
138           return second_if_op->isProperAncestor(op);
139         })) {
140       return_indices_to_keep.push_back(index_and_value.index());
141     }
142   }
143   return return_indices_to_keep;
144 }
145 
146 // Move the body excluding the terminators of else and then regions from
147 // 'source' to 'destination'.
MoveBranches(TF::IfRegionOp source,TF::IfRegionOp destination)148 void MoveBranches(TF::IfRegionOp source, TF::IfRegionOp destination) {
149   Block& destination_then_block = destination.then_branch().front();
150   auto& source_then_body = source.then_branch().front().getOperations();
151   destination_then_block.getOperations().splice(
152       destination_then_block.without_terminator().end(), source_then_body,
153       source_then_body.begin(), std::prev(source_then_body.end()));
154 
155   Block& destination_else_block = destination.else_branch().front();
156   auto& source_else_body = source.else_branch().front().getOperations();
157   destination_else_block.getOperations().splice(
158       destination_else_block.without_terminator().end(), source_else_body,
159       source_else_body.begin(), std::prev(source_else_body.end()));
160 }
161 
162 // Move all ops that depends on the results from `result_op` after `after_op`.
MoveResultsAfter(Operation * result_op,Operation * after_op,const TF::SideEffectAnalysis::Info & side_effect_analysis)163 void MoveResultsAfter(
164     Operation* result_op, Operation* after_op,
165     const TF::SideEffectAnalysis::Info& side_effect_analysis) {
166   std::queue<Operation*> queue;
167 
168   auto enqueue_deps = [&](Operation* source_op) {
169     for (Operation* user : source_op->getUsers()) {
170       queue.push(user);
171     }
172     source_op->walk([&](Operation* walked_op) {
173       for (Operation* successor :
174            side_effect_analysis.DirectControlSuccessors(walked_op)) {
175         if (!source_op->isProperAncestor(successor)) queue.push(successor);
176       }
177     });
178   };
179   enqueue_deps(result_op);
180 
181   while (!queue.empty()) {
182     auto* op = queue.front();
183     queue.pop();
184     while (op->getBlock() != after_op->getBlock()) op = op->getParentOp();
185     if (op->isBeforeInBlock(after_op)) {
186       op->moveAfter(after_op);
187       after_op = op;
188       enqueue_deps(op);
189     }
190   }
191 }
192 
CreateMergedIf(ArrayRef<int> source_return_indices_to_keep,ArrayRef<int> destination_return_indices_to_keep,TF::IfRegionOp source,TF::IfRegionOp destination,const TF::SideEffectAnalysis::Info & side_effect_analysis)193 TF::IfRegionOp CreateMergedIf(
194     ArrayRef<int> source_return_indices_to_keep,
195     ArrayRef<int> destination_return_indices_to_keep, TF::IfRegionOp source,
196     TF::IfRegionOp destination,
197     const TF::SideEffectAnalysis::Info& side_effect_analysis) {
198   llvm::SmallVector<Type, 4> merged_return_types;
199   for (int i : destination_return_indices_to_keep)
200     merged_return_types.push_back(destination.getResult(i).getType());
201   for (int i : source_return_indices_to_keep)
202     merged_return_types.push_back(source.getResult(i).getType());
203 
204   OpBuilder builder(destination);
205   // Create new IfRegion with correct merged results.
206   builder.setInsertionPoint(source.getOperation());
207 
208   auto new_if_op = builder.create<TF::IfRegionOp>(
209       destination.getLoc(), merged_return_types, destination.cond(),
210       destination.is_stateless() && source.is_stateless(),
211       destination._then_func_nameAttr(), destination._else_func_nameAttr());
212   new_if_op.then_branch().push_back(new Block);
213   new_if_op.else_branch().push_back(new Block);
214   // Replace internal usages of merged if ops.
215   for (OpResult result : destination.getResults()) {
216     replaceAllUsesInRegionWith(
217         result,
218         destination.then_branch().front().getTerminator()->getOperand(
219             result.getResultNumber()),
220         source.then_branch());
221     replaceAllUsesInRegionWith(
222         result,
223         destination.else_branch().front().getTerminator()->getOperand(
224             result.getResultNumber()),
225         source.else_branch());
226   }
227 
228   MoveResultsAfter(destination.getOperation(), new_if_op.getOperation(),
229                    side_effect_analysis);
230 
231   // Replace external usages of merged if ops.
232   int new_return_index = 0;
233   for (int i : destination_return_indices_to_keep) {
234     destination.getResult(i).replaceAllUsesWith(
235         new_if_op.getResult(new_return_index++));
236   }
237   for (int i : source_return_indices_to_keep) {
238     source.getResult(i).replaceAllUsesWith(
239         new_if_op.getResult(new_return_index++));
240   }
241 
242   // Create the Yield ops for both branches with merged results.
243   llvm::SmallVector<Value, 4> merged_then_yield_values;
244   for (int i : destination_return_indices_to_keep)
245     merged_then_yield_values.push_back(
246         destination.then_branch().front().getTerminator()->getOperand(i));
247   for (int i : source_return_indices_to_keep)
248     merged_then_yield_values.push_back(
249         source.then_branch().front().getTerminator()->getOperand(i));
250   builder.setInsertionPointToEnd(&new_if_op.then_branch().front());
251   builder.create<TF::YieldOp>(
252       destination.then_branch().front().getTerminator()->getLoc(),
253       /*operands=*/merged_then_yield_values);
254 
255   llvm::SmallVector<Value, 4> merged_else_yield_values;
256   for (int i : destination_return_indices_to_keep)
257     merged_else_yield_values.push_back(
258         destination.else_branch().front().getTerminator()->getOperand(i));
259   for (int i : source_return_indices_to_keep)
260     merged_else_yield_values.push_back(
261         source.else_branch().front().getTerminator()->getOperand(i));
262   builder.setInsertionPointToEnd(&new_if_op.else_branch().front());
263   builder.create<TF::YieldOp>(
264       destination.else_branch().front().getTerminator()->getLoc(),
265       /*operands=*/merged_else_yield_values);
266 
267   // Merge the two branch regions from both IfRegionOps into new IfRegionOp.
268   MoveBranches(/*source=*/destination, /*destination=*/new_if_op);
269   destination.erase();
270   MoveBranches(/*source=*/source, /*destination=*/new_if_op);
271   source.erase();
272   return new_if_op;
273 }
274 
275 // Groups if regions by common predicate and attemps to merge them.
OptimizeIfRegions(Block * block,ModuleOp module)276 void OptimizeIfRegions(Block* block, ModuleOp module) {
277   // Determine IfRegions with the same predicate.
278   llvm::SmallDenseMap<Value, llvm::SmallVector<TF::IfRegionOp, 8>, 8>
279       grouped_if_ops;
280   block->walk([&](TF::IfRegionOp if_op) {
281     auto it = grouped_if_ops.try_emplace(if_op.cond());
282     it.first->getSecond().push_back(if_op);
283   });
284 
285   auto side_effect_analysis = std::make_unique<TF::SideEffectAnalysis>(module);
286 
287   for (auto& entry : grouped_if_ops) {
288     auto& if_ops = entry.second;
289     for (auto it = if_ops.begin(); it != if_ops.end(); ++it) {
290       TF::IfRegionOp first_if_op = *it;
291       for (auto it2 = std::next(it); it2 != if_ops.end(); ++it2) {
292         FuncOp func = first_if_op->getParentOfType<FuncOp>();
293         const TF::SideEffectAnalysis::Info& analysis =
294             side_effect_analysis->GetAnalysisForFunc(func);
295 
296         TF::IfRegionOp second_if_op = *it2;
297         if (!SafeToMerge(second_if_op, first_if_op, analysis)) break;
298 
299         // For both check if there are uses outside of IfRegion, keep these as
300         // part of the return and replace the internal uses.
301         auto first_return_indices_to_keep =
302             GetReturnIndicesToKeep(first_if_op, second_if_op);
303         auto second_return_indices_to_keep =
304             GetReturnIndicesToKeep(second_if_op, first_if_op);
305 
306         auto new_if_op = CreateMergedIf(second_return_indices_to_keep,
307                                         first_return_indices_to_keep,
308                                         second_if_op, first_if_op, analysis);
309 
310         if_ops.erase(it2--);
311         first_if_op = new_if_op;
312         // We regenerate the side effect analysis since merging the IfRegions
313         // invalidates the side effect analysis.  This approach is O(N*M) where
314         // N is the number of ops in `module` and M is the number of pairs of
315         // IfRegion ops that are merged.
316         side_effect_analysis = std::make_unique<TF::SideEffectAnalysis>(module);
317       }
318     }
319   }
320 }
321 
runOnOperation()322 void MergeControlFlowPass::runOnOperation() {
323   ModuleOp module = getOperation();
324   auto result = module.walk([&](tf_device::ClusterOp cluster) {
325     OptimizeIfRegions(&cluster.GetBody(), module);
326     return WalkResult::advance();
327   });
328 
329   if (result.wasInterrupted()) return signalPassFailure();
330 }
331 
332 }  // namespace
333 
CreateMergeControlFlowPass()334 std::unique_ptr<OperationPass<ModuleOp>> CreateMergeControlFlowPass() {
335   return std::make_unique<MergeControlFlowPass>();
336 }
337 
338 }  // namespace TFDevice
339 }  // namespace mlir
340