• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cstdint>
17 
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/ADT/DenseSet.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
23 #include "mlir/IR/Attributes.h"  // from @llvm-project
24 #include "mlir/IR/Builders.h"  // from @llvm-project
25 #include "mlir/IR/Operation.h"  // from @llvm-project
26 #include "mlir/IR/Value.h"  // from @llvm-project
27 #include "mlir/Pass/Pass.h"  // from @llvm-project
28 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
32 
// This pass is used in preparation for Graph export.
// The GraphDef exporter expects each op to be in its own island.
// This pass puts the IR in that form by splitting every multi-op island into
// single-op islands that are chained together with control dependencies.
//
// We do this as an IR->IR transform to keep the Graph exporter as simple as
// possible.
39 
40 namespace mlir {
41 
42 namespace {
43 
44 class BreakUpIslands : public TF::PerFunctionAggregateAnalysisConsumerPass<
45                            BreakUpIslands, TF::SideEffectAnalysis> {
getDependentDialects(DialectRegistry & registry) const46   void getDependentDialects(DialectRegistry& registry) const override {
47     registry.insert<tf_executor::TensorFlowExecutorDialect>();
48   }
49 
50  public:
getArgument() const51   StringRef getArgument() const final { return "tf-executor-break-up-islands"; }
52 
getDescription() const53   StringRef getDescription() const final {
54     return "Transform from TF control dialect to TF executor dialect.";
55   }
56 
57   void runOnFunction(FuncOp func,
58                      const TF::SideEffectAnalysis::Info& side_effect_analysis);
59 
60   void BreakUpIsland(tf_executor::IslandOp island_op,
61                      const TF::SideEffectAnalysis::Info& side_effect_analysis,
62                      llvm::DenseMap<Operation*, llvm::SmallVector<Value, 4>>*
63                          new_control_inputs);
64 };
65 
runOnFunction(FuncOp func,const TF::SideEffectAnalysis::Info & side_effect_analysis)66 void BreakUpIslands::runOnFunction(
67     FuncOp func, const TF::SideEffectAnalysis::Info& side_effect_analysis) {
68   auto graph_op_range = func.front().without_terminator();
69   tf_executor::GraphOp graph_op;
70 
71   if (llvm::hasSingleElement(graph_op_range))
72     graph_op = dyn_cast<tf_executor::GraphOp>(func.front().front());
73 
74   if (!graph_op) {
75     func.emitError("expected function to contain only a graph_op");
76     signalPassFailure();
77     return;
78   }
79 
80   // New control inputs to be added. For an operation x, new_control_inputs[x]
81   // contains all control inputs that need to be added to x as operands.
82   llvm::DenseMap<Operation*, llvm::SmallVector<Value, 4>> new_control_inputs;
83   // Iterate in reverse order to avoid invalidating Operation* stored in
84   // new_control_inputs.
85   for (auto& item :
86        llvm::make_early_inc_range(llvm::reverse(graph_op.GetBody()))) {
87     if (auto island = dyn_cast<tf_executor::IslandOp>(&item)) {
88       BreakUpIsland(island, side_effect_analysis, &new_control_inputs);
89     }
90   }
91   OpBuilder builder(func);
92 
93   // For every op, add new control inputs in reverse order so that the ops don't
94   // get invalidated.
95   llvm::SmallVector<Value, 8> operands;
96   llvm::SmallPtrSet<Operation*, 4> defining_ops;
97   llvm::SmallVector<Type, 4> types;
98   for (auto& item :
99        llvm::make_early_inc_range(llvm::reverse(graph_op.GetBody()))) {
100     auto it = new_control_inputs.find(&item);
101     if (it == new_control_inputs.end()) continue;
102     auto& new_control_inputs_for_item = it->second;
103     builder.setInsertionPoint(&item);
104     OperationState state(item.getLoc(), item.getName());
105     types.assign(item.result_type_begin(), item.result_type_end());
106     state.addTypes(types);
107     for (Region& region : item.getRegions()) {
108       state.addRegion()->takeBody(region);
109     }
110     // Assign existing operands for item.
111     operands.assign(item.operand_begin(), item.operand_end());
112 
113     // Collect defining ops for existing operands.
114     defining_ops.clear();
115     for (Value operand : operands) {
116       defining_ops.insert(operand.getDefiningOp());
117     }
118     for (Value new_control_input : llvm::reverse(new_control_inputs_for_item)) {
119       // Add new control input if its defining op is not already a defining
120       // op for some other operand. Update defining_ops.
121       if (defining_ops.insert(new_control_input.getDefiningOp()).second) {
122         operands.push_back(new_control_input);
123       }
124     }
125     state.addOperands(operands);
126     Operation* new_op = builder.createOperation(state);
127     item.replaceAllUsesWith(new_op);
128     new_op->setAttrs(item.getAttrDictionary());
129     item.erase();
130   }
131 }
132 
133 // Populates an empty IslandOp and with a NoOp or Identity/IdentityN depending
134 // on if there are any data results.
PopulateEmptyIsland(tf_executor::IslandOp island)135 void PopulateEmptyIsland(tf_executor::IslandOp island) {
136   OpBuilder builder(&island.GetBody(), island.GetBody().begin());
137   tf_executor::YieldOp yield = island.GetYield();
138   if (yield.getNumOperands() == 0) {
139     builder.create<TF::NoOp>(island.getLoc(), TypeRange{}, ValueRange{});
140   } else if (yield.getNumOperands() == 1) {
141     Value operand = yield.getOperand(0);
142     auto identity = builder.create<TF::IdentityOp>(island.getLoc(),
143                                                    operand.getType(), operand);
144     yield.setOperand(0, identity.output());
145   } else {
146     auto identity_n = builder.create<TF::IdentityNOp>(
147         island.getLoc(), yield.getOperandTypes(), yield.getOperands());
148     for (auto it : llvm::enumerate(identity_n.getResults()))
149       yield.setOperand(it.index(), it.value());
150   }
151 }
152 
153 // Helper that creates an island. If `sub_op` is not nullptr, it will be moved
154 // to the island. Otherwise a NoOp will be added to the island.
CreateIsland(TypeRange result_types,ValueRange control_inputs,const tf_executor::ControlType & control_type,const Location & loc,Operation * sub_op,tf_executor::IslandOp original_island)155 tf_executor::IslandOp CreateIsland(TypeRange result_types,
156                                    ValueRange control_inputs,
157                                    const tf_executor::ControlType& control_type,
158                                    const Location& loc, Operation* sub_op,
159                                    tf_executor::IslandOp original_island) {
160   OpBuilder builder(original_island);
161   auto island = builder.create<tf_executor::IslandOp>(
162       loc, result_types, control_type, control_inputs);
163   island.body().push_back(new Block);
164   Block* block = &island.body().back();
165   OpBuilder island_builder(original_island);
166   island_builder.setInsertionPointToEnd(block);
167   if (sub_op) {
168     sub_op->replaceAllUsesWith(island.outputs());
169     sub_op->moveBefore(block, block->begin());
170     island_builder.create<tf_executor::YieldOp>(loc, sub_op->getResults());
171   } else {
172     island_builder.create<TF::NoOp>(island.getLoc(), TypeRange{}, ValueRange{});
173     island_builder.create<tf_executor::YieldOp>(loc, ValueRange{});
174   }
175   return island;
176 }
177 
178 // A struct that contains the operations in an island that need explicit control
179 // dependencies added going into and out of the island to capture inter-island
180 // dependencies properly.
181 struct IslandSourcesAndSinks {
182   // Sub-ops that need a control dependency going into the island. This includes
183   // sub-ops that do not depend on other sub-ops in the island and functional
184   // control ops (e.g. if, while, case) with side effects that must not take
185   // effect before the previous island is finished executing.
186   llvm::SmallPtrSet<Operation*, 4> sources;
187 
188   // Sub-ops that need a control dependency going out of the island. This
189   // includes sub-ops that do not have other sub-ops in the island depending on
190   // them (excluding yield) and functional control ops (e.g. if, while, case)
191   // with side effects that must take effect before the next island starts
192   // executing.
193   llvm::SmallPtrSet<Operation*, 4> sinks;
194 };
195 
196 // Returns true if the operation is a stateful If, Case, or While op.
IsStatefulFunctionalControlFlowOp(Operation * op)197 bool IsStatefulFunctionalControlFlowOp(Operation* op) {
198   if (!isa<TF::IfOp, TF::CaseOp, TF::WhileOp>(op)) {
199     return false;
200   }
201 
202   if (auto is_stateless = op->getAttrOfType<BoolAttr>("is_stateless")) {
203     return !is_stateless.getValue();
204   }
205   return false;
206 }
207 
208 // Finds IslandSourcesAndSinks for an unmodified island.
FindSourcesAndSinksInIsland(tf_executor::IslandOp island,const TF::SideEffectAnalysis::Info & side_effect_analysis)209 IslandSourcesAndSinks FindSourcesAndSinksInIsland(
210     tf_executor::IslandOp island,
211     const TF::SideEffectAnalysis::Info& side_effect_analysis) {
212   IslandSourcesAndSinks result;
213   auto island_body = island.GetBody().without_terminator();
214   for (Operation& sub_op : island_body) {
215     auto predecessors = side_effect_analysis.DirectControlPredecessors(&sub_op);
216     result.sinks.insert(&sub_op);
217     // Remove predecessor from sinks.
218     for (auto predecessor : predecessors) result.sinks.erase(predecessor);
219     bool has_in_island_operands = false;
220     for (auto operand : sub_op.getOperands()) {
221       auto defining_op = operand.getDefiningOp();
222       if (!defining_op || defining_op->getParentOp() != island) continue;
223       has_in_island_operands = true;
224 
225       // Remove operands from sinks.
226       // We don't remove the operand if it is a stateful functional control flow
227       // op to work around an issue in LowerFunctionalOpsPass where the operand
228       // dependency isn't enough to ensure the side effects take place
229       // (b/185483669).
230       if (!IsStatefulFunctionalControlFlowOp(defining_op)) {
231         result.sinks.erase(defining_op);
232       }
233     }
234     if (predecessors.empty() && (!has_in_island_operands ||
235                                  IsStatefulFunctionalControlFlowOp(&sub_op))) {
236       result.sources.insert(&sub_op);
237     }
238   }
239   return result;
240 }
241 
242 // Converts a single island into multiple islands (one for each op). The islands
243 // are chained together by control flow values.
BreakUpIsland(tf_executor::IslandOp island_op,const TF::SideEffectAnalysis::Info & side_effect_analysis,llvm::DenseMap<Operation *,llvm::SmallVector<Value,4>> * new_control_inputs)244 void BreakUpIslands::BreakUpIsland(
245     tf_executor::IslandOp island_op,
246     const TF::SideEffectAnalysis::Info& side_effect_analysis,
247     llvm::DenseMap<Operation*, llvm::SmallVector<Value, 4>>*
248         new_control_inputs) {
249   auto island_body = island_op.GetBody().without_terminator();
250   // Populate islands that are empty (only yield).
251   if (island_body.empty()) {
252     PopulateEmptyIsland(island_op);
253     return;
254   }
255 
256   // Skip islands that are already only a single op.
257   if (island_op.WrapsSingleOp()) return;
258 
259   auto control_type = tf_executor::ControlType::get(&getContext());
260   auto island_control_inputs = llvm::to_vector<4>(island_op.controlInputs());
261   // Add control dependencies for yields of values defined by other islands to
262   // the island that defines that fetched value.
263   for (auto fetch : island_op.GetYield().fetches()) {
264     if (!fetch.getDefiningOp()) {
265       // Skip, because there is no op to add control to (eg: function args).
266       continue;
267     } else if (fetch.getDefiningOp()->getParentOp() == island_op) {
268       // Skip, because it is the same island.
269       continue;
270     } else if (auto other_island_op = llvm::dyn_cast<tf_executor::IslandOp>(
271                    fetch.getDefiningOp())) {
272       island_control_inputs.push_back(other_island_op.control());
273     } else {
274       // TODO(parkers): Any defining op that has a control output can be handled
275       // just like an island.
276       fetch.getDefiningOp()->emitError("fetching non-island as dependency");
277       return signalPassFailure();
278     }
279   }
280   // If there are multiple control inputs, create an empty island to group them.
281   if (island_control_inputs.size() > 1) {
282     auto new_island = CreateIsland({}, island_control_inputs, control_type,
283                                    island_op.getLoc(), nullptr, island_op);
284     island_control_inputs.clear();
285     island_control_inputs.push_back(new_island.control());
286   }
287   // Find sources and sinks inside the original island.
288   IslandSourcesAndSinks sources_and_sinks =
289       FindSourcesAndSinksInIsland(island_op, side_effect_analysis);
290   // The corresponding control output of the new island created for each sub-op.
291   llvm::SmallDenseMap<Operation*, Value, 8> new_control_for_sub_ops;
292   // Control outputs of newly created islands that are sinks.
293   llvm::SmallVector<Value, 8> sink_island_controls;
294   // For each operation in the island, construct a new island to wrap the op,
295   // yield all the results, and replace all the usages with the results of the
296   // new island.
297   for (auto& sub_op : llvm::make_early_inc_range(island_body)) {
298     const auto predecessors =
299         side_effect_analysis.DirectControlPredecessors(&sub_op);
300     // Get the controls from the predecessors.
301     llvm::SmallVector<Value, 4> predecessor_controls;
302     predecessor_controls.reserve(predecessors.size());
303     for (auto predecessor : predecessors) {
304       predecessor_controls.push_back(new_control_for_sub_ops[predecessor]);
305     }
306     // If sub_op is a source, use island_control_inputs, because that's required
307     // by inter-islands dependencies; otherwise, we do not need to include
308     // island_control_inputs, since they must have been tracked by the (direct
309     // or indirect) control predecessors or operands.
310     ArrayRef<Value> control = sources_and_sinks.sources.count(&sub_op) > 0
311                                   ? island_control_inputs
312                                   : predecessor_controls;
313     auto new_island =
314         CreateIsland(sub_op.getResultTypes(), control, control_type,
315                      sub_op.getLoc(), &sub_op, island_op);
316     new_control_for_sub_ops[&sub_op] = new_island.control();
317     if (sources_and_sinks.sinks.count(&sub_op)) {
318       sink_island_controls.push_back(new_island.control());
319     }
320   }
321   // Create control outputs for the sinks.
322   assert(!sink_island_controls.empty());
323   // If there are multiple control outputs, create an empty island to group
324   // them.
325   if (sink_island_controls.size() > 1) {
326     auto new_island = CreateIsland({}, sink_island_controls, control_type,
327                                    island_op.getLoc(), nullptr, island_op);
328     sink_island_controls.clear();
329     sink_island_controls.push_back(new_island.control());
330   }
331   assert(sink_island_controls.size() == 1);
332   auto& sink_island_control = sink_island_controls[0];
333   island_op.control().replaceAllUsesWith(sink_island_control);
334   // All existing outputs need to add sink_island_control as control input.
335   // GraphOp, YieldOp and NextIterationSourceOp don't have control inputs so
336   // exclude them below.
337   for (Value out : island_op.outputs()) {
338     for (auto& use : out.getUses()) {
339       Operation* owner = use.getOwner();
340       if (auto other_island_op =
341               llvm::dyn_cast<tf_executor::IslandOp>(owner->getParentOp())) {
342         (*new_control_inputs)[other_island_op].push_back(sink_island_control);
343       } else if (owner->getDialect() == island_op->getDialect() &&
344                  !llvm::isa<tf_executor::GraphOp, tf_executor::YieldOp,
345                             tf_executor::NextIterationSourceOp>(owner)) {
346         (*new_control_inputs)[owner].push_back(sink_island_control);
347       } else {
348         owner->emitOpError("adding control dependency not supported");
349         return signalPassFailure();
350       }
351     }
352   }
353   for (auto item :
354        llvm::zip(island_op.outputs(), island_op.GetYield().fetches()))
355     std::get<0>(item).replaceAllUsesWith(std::get<1>(item));
356   island_op.erase();
357 }
358 
359 }  // namespace
360 
CreateBreakUpIslandsPass()361 std::unique_ptr<OperationPass<ModuleOp>> CreateBreakUpIslandsPass() {
362   return std::make_unique<BreakUpIslands>();
363 }
364 
365 }  // namespace mlir
366 
367 static mlir::PassRegistration<mlir::BreakUpIslands> pass;
368