• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cstdint>
17 
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/ADT/DenseSet.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
23 #include "mlir/IR/Attributes.h"  // from @llvm-project
24 #include "mlir/IR/Builders.h"  // from @llvm-project
25 #include "mlir/IR/Operation.h"  // from @llvm-project
26 #include "mlir/IR/Value.h"  // from @llvm-project
27 #include "mlir/Pass/Pass.h"  // from @llvm-project
28 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
32 
33 // This pass is used in preparation for Graph export.
34 // The GraphDef exporter expects each op to be in its own island.
35 // This pass puts the IR in that form.
36 //
37 // We do this as an IR->IR transform to keep the Graph exporter as simple as
38 // possible.
39 
40 namespace mlir {
41 
42 namespace {
43 
44 class BreakUpIslands : public TF::PerFunctionAggregateAnalysisConsumerPass<
45                            BreakUpIslands, TF::SideEffectAnalysis> {
getDependentDialects(DialectRegistry & registry) const46   void getDependentDialects(DialectRegistry& registry) const override {
47     registry.insert<tf_executor::TensorFlowExecutorDialect>();
48   }
49 
50  public:
51   void runOnFunction(FuncOp func,
52                      const TF::SideEffectAnalysis::Info& side_effect_analysis);
53 
54   void BreakUpIsland(tf_executor::IslandOp island_op,
55                      const TF::SideEffectAnalysis::Info& side_effect_analysis,
56                      llvm::DenseMap<Operation*, llvm::SmallVector<Value, 4>>*
57                          new_control_inputs);
58 };
59 
runOnFunction(FuncOp func,const TF::SideEffectAnalysis::Info & side_effect_analysis)60 void BreakUpIslands::runOnFunction(
61     FuncOp func, const TF::SideEffectAnalysis::Info& side_effect_analysis) {
62   auto graph_op_range = func.front().without_terminator();
63   tf_executor::GraphOp graph_op;
64 
65   if (llvm::hasSingleElement(graph_op_range))
66     graph_op = dyn_cast<tf_executor::GraphOp>(func.front().front());
67 
68   if (!graph_op) {
69     func.emitError("expected function to contain only a graph_op");
70     signalPassFailure();
71     return;
72   }
73 
74   // New control inputs to be added. For an operation x, new_control_inputs[x]
75   // contains all control inputs that need to be added to x as operands.
76   llvm::DenseMap<Operation*, llvm::SmallVector<Value, 4>> new_control_inputs;
77   // Iterate in reverse order to avoid invalidating Operation* stored in
78   // new_control_inputs.
79   for (auto& item :
80        llvm::make_early_inc_range(llvm::reverse(graph_op.GetBody()))) {
81     if (auto island = dyn_cast<tf_executor::IslandOp>(&item)) {
82       BreakUpIsland(island, side_effect_analysis, &new_control_inputs);
83     }
84   }
85   OpBuilder builder(func);
86 
87   // For every op, add new control inputs in reverse order so that the ops don't
88   // get invalidated.
89   llvm::SmallVector<Value, 8> operands;
90   llvm::SmallPtrSet<Operation*, 4> defining_ops;
91   llvm::SmallVector<Type, 4> types;
92   for (auto& item :
93        llvm::make_early_inc_range(llvm::reverse(graph_op.GetBody()))) {
94     auto it = new_control_inputs.find(&item);
95     if (it == new_control_inputs.end()) continue;
96     auto& new_control_inputs_for_item = it->second;
97     builder.setInsertionPoint(&item);
98     OperationState state(item.getLoc(), item.getName());
99     types.assign(item.result_type_begin(), item.result_type_end());
100     state.addTypes(types);
101     for (Region& region : item.getRegions()) {
102       state.addRegion()->takeBody(region);
103     }
104     // Assign existing operands for item.
105     operands.assign(item.operand_begin(), item.operand_end());
106 
107     // Collect defining ops for existing operands.
108     defining_ops.clear();
109     for (Value operand : operands) {
110       defining_ops.insert(operand.getDefiningOp());
111     }
112     for (Value new_control_input : llvm::reverse(new_control_inputs_for_item)) {
113       // Add new control input if its defining op is not already a defining
114       // op for some other operand. Update defining_ops.
115       if (defining_ops.insert(new_control_input.getDefiningOp()).second) {
116         operands.push_back(new_control_input);
117       }
118     }
119     state.addOperands(operands);
120     Operation* new_op = builder.createOperation(state);
121     item.replaceAllUsesWith(new_op);
122     new_op->setAttrs(item.getAttrDictionary());
123     item.erase();
124   }
125 }
126 
127 // Populates an empty IslandOp and with a NoOp or Identity/IdentityN depending
128 // on if there are any data results.
PopulateEmptyIsland(tf_executor::IslandOp island)129 void PopulateEmptyIsland(tf_executor::IslandOp island) {
130   OpBuilder builder(&island.GetBody(), island.GetBody().begin());
131   tf_executor::YieldOp yield = island.GetYield();
132   if (yield.getNumOperands() == 0) {
133     builder.create<TF::NoOp>(island.getLoc(), TypeRange{}, ValueRange{});
134   } else if (yield.getNumOperands() == 1) {
135     Value operand = yield.getOperand(0);
136     auto identity = builder.create<TF::IdentityOp>(island.getLoc(),
137                                                    operand.getType(), operand);
138     yield.setOperand(0, identity.output());
139   } else {
140     auto identity_n = builder.create<TF::IdentityNOp>(
141         island.getLoc(), yield.getOperandTypes(), yield.getOperands());
142     for (auto it : llvm::enumerate(identity_n.getResults()))
143       yield.setOperand(it.index(), it.value());
144   }
145 }
146 
147 // Helper that creates an island. If `sub_op` is not nullptr, it will be moved
148 // to the island. Otherwise a NoOp will be added to the island.
CreateIsland(TypeRange result_types,ValueRange control_inputs,const tf_executor::ControlType & control_type,const Location & loc,Operation * sub_op,tf_executor::IslandOp original_island)149 tf_executor::IslandOp CreateIsland(TypeRange result_types,
150                                    ValueRange control_inputs,
151                                    const tf_executor::ControlType& control_type,
152                                    const Location& loc, Operation* sub_op,
153                                    tf_executor::IslandOp original_island) {
154   OpBuilder builder(original_island);
155   auto island = builder.create<tf_executor::IslandOp>(
156       loc, result_types, control_type, control_inputs);
157   island.body().push_back(new Block);
158   Block* block = &island.body().back();
159   OpBuilder island_builder(original_island);
160   island_builder.setInsertionPointToEnd(block);
161   if (sub_op) {
162     sub_op->replaceAllUsesWith(island.outputs());
163     sub_op->moveBefore(block, block->begin());
164     island_builder.create<tf_executor::YieldOp>(loc, sub_op->getResults());
165   } else {
166     island_builder.create<TF::NoOp>(island.getLoc(), TypeRange{}, ValueRange{});
167     island_builder.create<tf_executor::YieldOp>(loc, ValueRange{});
168   }
169   return island;
170 }
171 
172 // A struct contains the operations in an island that do not have incoming or
173 // outgoing dependencies.
174 struct IslandSourcesAndSinks {
175   // Sub-ops that do not depend on other sub-ops in the island.
176   llvm::SmallPtrSet<Operation*, 4> sources;
177   // Sub-ops that do not have other sub-ops in the island depending on them
178   // (excluding yield).
179   llvm::SmallPtrSet<Operation*, 4> sinks;
180 };
181 
182 // Finds IslandSourcesAndSinks for an unmodified island.
FindSourcesAndSinksInIsland(tf_executor::IslandOp island,const TF::SideEffectAnalysis::Info & side_effect_analysis)183 IslandSourcesAndSinks FindSourcesAndSinksInIsland(
184     tf_executor::IslandOp island,
185     const TF::SideEffectAnalysis::Info& side_effect_analysis) {
186   IslandSourcesAndSinks result;
187   auto island_body = island.GetBody().without_terminator();
188   for (Operation& sub_op : island_body) {
189     auto predecessors = side_effect_analysis.DirectControlPredecessors(&sub_op);
190     result.sinks.insert(&sub_op);
191     // Remove predecessor from sinks.
192     for (auto predecessor : predecessors) result.sinks.erase(predecessor);
193     bool has_in_island_operands = false;
194     for (auto operand : sub_op.getOperands()) {
195       auto defining_op = operand.getDefiningOp();
196       if (!defining_op || defining_op->getParentOp() != island) continue;
197       // Remove operands from sinks.
198       result.sinks.erase(defining_op);
199       has_in_island_operands = true;
200     }
201     if (predecessors.empty() && !has_in_island_operands) {
202       result.sources.insert(&sub_op);
203     }
204   }
205   return result;
206 }
207 
208 // Converts a single island into multiple islands (one for each op). The islands
209 // are chained together by control flow values.
BreakUpIsland(tf_executor::IslandOp island_op,const TF::SideEffectAnalysis::Info & side_effect_analysis,llvm::DenseMap<Operation *,llvm::SmallVector<Value,4>> * new_control_inputs)210 void BreakUpIslands::BreakUpIsland(
211     tf_executor::IslandOp island_op,
212     const TF::SideEffectAnalysis::Info& side_effect_analysis,
213     llvm::DenseMap<Operation*, llvm::SmallVector<Value, 4>>*
214         new_control_inputs) {
215   auto island_body = island_op.GetBody().without_terminator();
216   // Populate islands that are empty (only yield).
217   if (island_body.empty()) {
218     PopulateEmptyIsland(island_op);
219     return;
220   }
221 
222   // Skip islands that are already only a single op.
223   if (island_op.WrapsSingleOp()) return;
224 
225   auto control_type = tf_executor::ControlType::get(&getContext());
226   auto island_control_inputs = llvm::to_vector<4>(island_op.controlInputs());
227   // Add control dependencies for yields of values defined by other islands to
228   // the island that defines that fetched value.
229   for (auto fetch : island_op.GetYield().fetches()) {
230     if (!fetch.getDefiningOp()) {
231       // Skip, because there is no op to add control to (eg: function args).
232       continue;
233     } else if (fetch.getDefiningOp()->getParentOp() == island_op) {
234       // Skip, because it is the same island.
235       continue;
236     } else if (auto other_island_op = llvm::dyn_cast<tf_executor::IslandOp>(
237                    fetch.getDefiningOp())) {
238       island_control_inputs.push_back(other_island_op.control());
239     } else {
240       // TODO(parkers): Any defining op that has a control output can be handled
241       // just like an island.
242       fetch.getDefiningOp()->emitError("fetching non-island as dependency");
243       return signalPassFailure();
244     }
245   }
246   // If there are multiple control inputs, create an empty island to group them.
247   if (island_control_inputs.size() > 1) {
248     auto new_island = CreateIsland({}, island_control_inputs, control_type,
249                                    island_op.getLoc(), nullptr, island_op);
250     island_control_inputs.clear();
251     island_control_inputs.push_back(new_island.control());
252   }
253   // Find sources and sinks inside the original island.
254   auto sources_and_sinks =
255       FindSourcesAndSinksInIsland(island_op, side_effect_analysis);
256   // The corresponding control output of the new island created for each sub-op.
257   llvm::SmallDenseMap<Operation*, Value, 8> new_control_for_sub_ops;
258   // Control outputs of newly created islands that are sinks.
259   llvm::SmallVector<Value, 8> sink_island_controls;
260   // For each operation in the island, construct a new island to wrap the op,
261   // yield all the results, and replace all the usages with the results of the
262   // new island.
263   for (auto& sub_op : llvm::make_early_inc_range(island_body)) {
264     const auto predecessors =
265         side_effect_analysis.DirectControlPredecessors(&sub_op);
266     // Get the controls from the predecessors.
267     llvm::SmallVector<Value, 4> predecessor_controls;
268     predecessor_controls.reserve(predecessors.size());
269     for (auto predecessor : predecessors) {
270       predecessor_controls.push_back(new_control_for_sub_ops[predecessor]);
271     }
272     // If sub_op is a source, use island_control_inputs, because that's required
273     // by inter-islands dependencies; otherwise, we do not need to include
274     // island_control_inputs, since they must have been tracked by the (direct
275     // or indirect) control predecessors or operands.
276     ArrayRef<Value> control = sources_and_sinks.sources.count(&sub_op) > 0
277                                   ? island_control_inputs
278                                   : predecessor_controls;
279     auto new_island =
280         CreateIsland(sub_op.getResultTypes(), control, control_type,
281                      sub_op.getLoc(), &sub_op, island_op);
282     new_control_for_sub_ops[&sub_op] = new_island.control();
283     if (sources_and_sinks.sinks.count(&sub_op)) {
284       sink_island_controls.push_back(new_island.control());
285     }
286   }
287   // Create control outputs for the sinks.
288   assert(!sink_island_controls.empty());
289   // If there are multiple control outputs, create an empty island to group
290   // them.
291   if (sink_island_controls.size() > 1) {
292     auto new_island = CreateIsland({}, sink_island_controls, control_type,
293                                    island_op.getLoc(), nullptr, island_op);
294     sink_island_controls.clear();
295     sink_island_controls.push_back(new_island.control());
296   }
297   assert(sink_island_controls.size() == 1);
298   auto& sink_island_control = sink_island_controls[0];
299   island_op.control().replaceAllUsesWith(sink_island_control);
300   // All existing outputs need to add sink_island_control as control input.
301   // GraphOp, YieldOp and NextIterationSourceOp don't have control inputs so
302   // exclude them below.
303   for (Value out : island_op.outputs()) {
304     for (auto& use : out.getUses()) {
305       Operation* owner = use.getOwner();
306       if (auto other_island_op =
307               llvm::dyn_cast<tf_executor::IslandOp>(owner->getParentOp())) {
308         (*new_control_inputs)[other_island_op].push_back(sink_island_control);
309       } else if (owner->getDialect() == island_op->getDialect() &&
310                  !llvm::isa<tf_executor::GraphOp, tf_executor::YieldOp,
311                             tf_executor::NextIterationSourceOp>(owner)) {
312         (*new_control_inputs)[owner].push_back(sink_island_control);
313       } else {
314         owner->emitOpError("adding control dependency not supported");
315         return signalPassFailure();
316       }
317     }
318   }
319   for (auto item :
320        llvm::zip(island_op.outputs(), island_op.GetYield().fetches()))
321     std::get<0>(item).replaceAllUsesWith(std::get<1>(item));
322   island_op.erase();
323 }
324 
325 }  // namespace
326 
CreateBreakUpIslandsPass()327 std::unique_ptr<OperationPass<ModuleOp>> CreateBreakUpIslandsPass() {
328   return std::make_unique<BreakUpIslands>();
329 }
330 
331 }  // namespace mlir
332 
333 static mlir::PassRegistration<mlir::BreakUpIslands> pass(
334     "tf-executor-break-up-islands",
335     "Transform from TF control dialect to TF executor dialect.");
336