1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <cstdint>
17
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/ADT/DenseSet.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
23 #include "mlir/IR/Attributes.h" // from @llvm-project
24 #include "mlir/IR/Builders.h" // from @llvm-project
25 #include "mlir/IR/Operation.h" // from @llvm-project
26 #include "mlir/IR/Value.h" // from @llvm-project
27 #include "mlir/Pass/Pass.h" // from @llvm-project
28 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
32
33 // This pass is used in preparation for Graph export.
34 // The GraphDef exporter expects each op to be in its own island.
35 // This pass puts the IR in that form.
36 //
37 // We do this as an IR->IR transform to keep the Graph exporter as simple as
38 // possible.
39
40 namespace mlir {
41
42 namespace {
43
44 class BreakUpIslands : public TF::PerFunctionAggregateAnalysisConsumerPass<
45 BreakUpIslands, TF::SideEffectAnalysis> {
getDependentDialects(DialectRegistry & registry) const46 void getDependentDialects(DialectRegistry& registry) const override {
47 registry.insert<tf_executor::TensorFlowExecutorDialect>();
48 }
49
50 public:
getArgument() const51 StringRef getArgument() const final { return "tf-executor-break-up-islands"; }
52
getDescription() const53 StringRef getDescription() const final {
54 return "Transform from TF control dialect to TF executor dialect.";
55 }
56
57 void runOnFunction(FuncOp func,
58 const TF::SideEffectAnalysis::Info& side_effect_analysis);
59
60 void BreakUpIsland(tf_executor::IslandOp island_op,
61 const TF::SideEffectAnalysis::Info& side_effect_analysis,
62 llvm::DenseMap<Operation*, llvm::SmallVector<Value, 4>>*
63 new_control_inputs);
64 };
65
runOnFunction(FuncOp func,const TF::SideEffectAnalysis::Info & side_effect_analysis)66 void BreakUpIslands::runOnFunction(
67 FuncOp func, const TF::SideEffectAnalysis::Info& side_effect_analysis) {
68 auto graph_op_range = func.front().without_terminator();
69 tf_executor::GraphOp graph_op;
70
71 if (llvm::hasSingleElement(graph_op_range))
72 graph_op = dyn_cast<tf_executor::GraphOp>(func.front().front());
73
74 if (!graph_op) {
75 func.emitError("expected function to contain only a graph_op");
76 signalPassFailure();
77 return;
78 }
79
80 // New control inputs to be added. For an operation x, new_control_inputs[x]
81 // contains all control inputs that need to be added to x as operands.
82 llvm::DenseMap<Operation*, llvm::SmallVector<Value, 4>> new_control_inputs;
83 // Iterate in reverse order to avoid invalidating Operation* stored in
84 // new_control_inputs.
85 for (auto& item :
86 llvm::make_early_inc_range(llvm::reverse(graph_op.GetBody()))) {
87 if (auto island = dyn_cast<tf_executor::IslandOp>(&item)) {
88 BreakUpIsland(island, side_effect_analysis, &new_control_inputs);
89 }
90 }
91 OpBuilder builder(func);
92
93 // For every op, add new control inputs in reverse order so that the ops don't
94 // get invalidated.
95 llvm::SmallVector<Value, 8> operands;
96 llvm::SmallPtrSet<Operation*, 4> defining_ops;
97 llvm::SmallVector<Type, 4> types;
98 for (auto& item :
99 llvm::make_early_inc_range(llvm::reverse(graph_op.GetBody()))) {
100 auto it = new_control_inputs.find(&item);
101 if (it == new_control_inputs.end()) continue;
102 auto& new_control_inputs_for_item = it->second;
103 builder.setInsertionPoint(&item);
104 OperationState state(item.getLoc(), item.getName());
105 types.assign(item.result_type_begin(), item.result_type_end());
106 state.addTypes(types);
107 for (Region& region : item.getRegions()) {
108 state.addRegion()->takeBody(region);
109 }
110 // Assign existing operands for item.
111 operands.assign(item.operand_begin(), item.operand_end());
112
113 // Collect defining ops for existing operands.
114 defining_ops.clear();
115 for (Value operand : operands) {
116 defining_ops.insert(operand.getDefiningOp());
117 }
118 for (Value new_control_input : llvm::reverse(new_control_inputs_for_item)) {
119 // Add new control input if its defining op is not already a defining
120 // op for some other operand. Update defining_ops.
121 if (defining_ops.insert(new_control_input.getDefiningOp()).second) {
122 operands.push_back(new_control_input);
123 }
124 }
125 state.addOperands(operands);
126 Operation* new_op = builder.createOperation(state);
127 item.replaceAllUsesWith(new_op);
128 new_op->setAttrs(item.getAttrDictionary());
129 item.erase();
130 }
131 }
132
133 // Populates an empty IslandOp and with a NoOp or Identity/IdentityN depending
134 // on if there are any data results.
PopulateEmptyIsland(tf_executor::IslandOp island)135 void PopulateEmptyIsland(tf_executor::IslandOp island) {
136 OpBuilder builder(&island.GetBody(), island.GetBody().begin());
137 tf_executor::YieldOp yield = island.GetYield();
138 if (yield.getNumOperands() == 0) {
139 builder.create<TF::NoOp>(island.getLoc(), TypeRange{}, ValueRange{});
140 } else if (yield.getNumOperands() == 1) {
141 Value operand = yield.getOperand(0);
142 auto identity = builder.create<TF::IdentityOp>(island.getLoc(),
143 operand.getType(), operand);
144 yield.setOperand(0, identity.output());
145 } else {
146 auto identity_n = builder.create<TF::IdentityNOp>(
147 island.getLoc(), yield.getOperandTypes(), yield.getOperands());
148 for (auto it : llvm::enumerate(identity_n.getResults()))
149 yield.setOperand(it.index(), it.value());
150 }
151 }
152
153 // Helper that creates an island. If `sub_op` is not nullptr, it will be moved
154 // to the island. Otherwise a NoOp will be added to the island.
CreateIsland(TypeRange result_types,ValueRange control_inputs,const tf_executor::ControlType & control_type,const Location & loc,Operation * sub_op,tf_executor::IslandOp original_island)155 tf_executor::IslandOp CreateIsland(TypeRange result_types,
156 ValueRange control_inputs,
157 const tf_executor::ControlType& control_type,
158 const Location& loc, Operation* sub_op,
159 tf_executor::IslandOp original_island) {
160 OpBuilder builder(original_island);
161 auto island = builder.create<tf_executor::IslandOp>(
162 loc, result_types, control_type, control_inputs);
163 island.body().push_back(new Block);
164 Block* block = &island.body().back();
165 OpBuilder island_builder(original_island);
166 island_builder.setInsertionPointToEnd(block);
167 if (sub_op) {
168 sub_op->replaceAllUsesWith(island.outputs());
169 sub_op->moveBefore(block, block->begin());
170 island_builder.create<tf_executor::YieldOp>(loc, sub_op->getResults());
171 } else {
172 island_builder.create<TF::NoOp>(island.getLoc(), TypeRange{}, ValueRange{});
173 island_builder.create<tf_executor::YieldOp>(loc, ValueRange{});
174 }
175 return island;
176 }
177
178 // A struct that contains the operations in an island that need explicit control
179 // dependencies added going into and out of the island to capture inter-island
180 // dependencies properly.
181 struct IslandSourcesAndSinks {
182 // Sub-ops that need a control dependency going into the island. This includes
183 // sub-ops that do not depend on other sub-ops in the island and functional
184 // control ops (e.g. if, while, case) with side effects that must not take
185 // effect before the previous island is finished executing.
186 llvm::SmallPtrSet<Operation*, 4> sources;
187
188 // Sub-ops that need a control dependency going out of the island. This
189 // includes sub-ops that do not have other sub-ops in the island depending on
190 // them (excluding yield) and functional control ops (e.g. if, while, case)
191 // with side effects that must take effect before the next island starts
192 // executing.
193 llvm::SmallPtrSet<Operation*, 4> sinks;
194 };
195
196 // Returns true if the operation is a stateful If, Case, or While op.
IsStatefulFunctionalControlFlowOp(Operation * op)197 bool IsStatefulFunctionalControlFlowOp(Operation* op) {
198 if (!isa<TF::IfOp, TF::CaseOp, TF::WhileOp>(op)) {
199 return false;
200 }
201
202 if (auto is_stateless = op->getAttrOfType<BoolAttr>("is_stateless")) {
203 return !is_stateless.getValue();
204 }
205 return false;
206 }
207
208 // Finds IslandSourcesAndSinks for an unmodified island.
FindSourcesAndSinksInIsland(tf_executor::IslandOp island,const TF::SideEffectAnalysis::Info & side_effect_analysis)209 IslandSourcesAndSinks FindSourcesAndSinksInIsland(
210 tf_executor::IslandOp island,
211 const TF::SideEffectAnalysis::Info& side_effect_analysis) {
212 IslandSourcesAndSinks result;
213 auto island_body = island.GetBody().without_terminator();
214 for (Operation& sub_op : island_body) {
215 auto predecessors = side_effect_analysis.DirectControlPredecessors(&sub_op);
216 result.sinks.insert(&sub_op);
217 // Remove predecessor from sinks.
218 for (auto predecessor : predecessors) result.sinks.erase(predecessor);
219 bool has_in_island_operands = false;
220 for (auto operand : sub_op.getOperands()) {
221 auto defining_op = operand.getDefiningOp();
222 if (!defining_op || defining_op->getParentOp() != island) continue;
223 has_in_island_operands = true;
224
225 // Remove operands from sinks.
226 // We don't remove the operand if it is a stateful functional control flow
227 // op to work around an issue in LowerFunctionalOpsPass where the operand
228 // dependency isn't enough to ensure the side effects take place
229 // (b/185483669).
230 if (!IsStatefulFunctionalControlFlowOp(defining_op)) {
231 result.sinks.erase(defining_op);
232 }
233 }
234 if (predecessors.empty() && (!has_in_island_operands ||
235 IsStatefulFunctionalControlFlowOp(&sub_op))) {
236 result.sources.insert(&sub_op);
237 }
238 }
239 return result;
240 }
241
// Converts a single island into multiple islands (one for each op). The islands
// are chained together by control flow values.
//
// - Empty islands (only a yield) are populated via PopulateEmptyIsland.
// - Islands that already wrap a single op are left untouched.
// - Otherwise, each op is moved into its own new island, created before
//   `island_op`. Ops are ordered among themselves by the side effect analysis'
//   direct control predecessors.
// - The original island's control inputs, plus the controls of other islands
//   whose values it yields, are routed to the new "source" islands.
// - Users of the original island's control and data outputs are made to
//   depend on the new "sink" islands.
// - Extra control inputs that existing ops need are not applied here. They are
//   recorded in `new_control_inputs`, keyed by the op that needs them.
// - On unsupported IR an error is emitted and the pass is marked as failed.
void BreakUpIslands::BreakUpIsland(
    tf_executor::IslandOp island_op,
    const TF::SideEffectAnalysis::Info& side_effect_analysis,
    llvm::DenseMap<Operation*, llvm::SmallVector<Value, 4>>*
        new_control_inputs) {
  auto island_body = island_op.GetBody().without_terminator();
  // Populate islands that are empty (only yield).
  if (island_body.empty()) {
    PopulateEmptyIsland(island_op);
    return;
  }

  // Skip islands that are already only a single op.
  if (island_op.WrapsSingleOp()) return;

  auto control_type = tf_executor::ControlType::get(&getContext());
  // Start from the island's own control inputs; more are appended below.
  auto island_control_inputs = llvm::to_vector<4>(island_op.controlInputs());
  // Add control dependencies for yields of values defined by other islands to
  // the island that defines that fetched value.
  for (auto fetch : island_op.GetYield().fetches()) {
    if (!fetch.getDefiningOp()) {
      // Skip, because there is no op to add control to (eg: function args).
      continue;
    } else if (fetch.getDefiningOp()->getParentOp() == island_op) {
      // Skip, because it is the same island.
      continue;
    } else if (auto other_island_op = llvm::dyn_cast<tf_executor::IslandOp>(
                   fetch.getDefiningOp())) {
      island_control_inputs.push_back(other_island_op.control());
    } else {
      // TODO(parkers): Any defining op that has a control output can be handled
      // just like an island.
      fetch.getDefiningOp()->emitError("fetching non-island as dependency");
      return signalPassFailure();
    }
  }
  // If there are multiple control inputs, create an empty island to group them.
  // (CreateIsland with a null sub_op creates an island holding a single NoOp.)
  if (island_control_inputs.size() > 1) {
    auto new_island = CreateIsland({}, island_control_inputs, control_type,
                                   island_op.getLoc(), nullptr, island_op);
    island_control_inputs.clear();
    island_control_inputs.push_back(new_island.control());
  }
  // Find sources and sinks inside the original island.
  IslandSourcesAndSinks sources_and_sinks =
      FindSourcesAndSinksInIsland(island_op, side_effect_analysis);
  // The corresponding control output of the new island created for each sub-op.
  llvm::SmallDenseMap<Operation*, Value, 8> new_control_for_sub_ops;
  // Control outputs of newly created islands that are sinks.
  llvm::SmallVector<Value, 8> sink_island_controls;
  // For each operation in the island, construct a new island to wrap the op,
  // yield all the results, and replace all the usages with the results of the
  // new island.
  // make_early_inc_range is required because CreateIsland moves `sub_op` out
  // of this island's body while we iterate over it.
  for (auto& sub_op : llvm::make_early_inc_range(island_body)) {
    const auto predecessors =
        side_effect_analysis.DirectControlPredecessors(&sub_op);
    // Get the controls from the predecessors.
    // NOTE(review): operator[] default-inserts a null Value if `predecessor`
    // has not been visited yet. This relies on every direct control
    // predecessor being an earlier op of this same island; confirm against
    // SideEffectAnalysis.
    llvm::SmallVector<Value, 4> predecessor_controls;
    predecessor_controls.reserve(predecessors.size());
    for (auto predecessor : predecessors) {
      predecessor_controls.push_back(new_control_for_sub_ops[predecessor]);
    }
    // If sub_op is a source, use island_control_inputs, because that's required
    // by inter-islands dependencies; otherwise, we do not need to include
    // island_control_inputs, since they must have been tracked by the (direct
    // or indirect) control predecessors or operands.
    ArrayRef<Value> control = sources_and_sinks.sources.count(&sub_op) > 0
                                  ? island_control_inputs
                                  : predecessor_controls;
    auto new_island =
        CreateIsland(sub_op.getResultTypes(), control, control_type,
                     sub_op.getLoc(), &sub_op, island_op);
    new_control_for_sub_ops[&sub_op] = new_island.control();
    if (sources_and_sinks.sinks.count(&sub_op)) {
      sink_island_controls.push_back(new_island.control());
    }
  }
  // Create control outputs for the sinks.
  // The last op of the island is presumably never removed from the sinks (no
  // later op can depend on it), so at least one sink control exists.
  assert(!sink_island_controls.empty());
  // If there are multiple control outputs, create an empty island to group
  // them.
  if (sink_island_controls.size() > 1) {
    auto new_island = CreateIsland({}, sink_island_controls, control_type,
                                   island_op.getLoc(), nullptr, island_op);
    sink_island_controls.clear();
    sink_island_controls.push_back(new_island.control());
  }
  assert(sink_island_controls.size() == 1);
  auto& sink_island_control = sink_island_controls[0];
  // Anything that depended on the original island's control output now
  // depends on the (grouped) sink control instead.
  island_op.control().replaceAllUsesWith(sink_island_control);
  // All existing outputs need to add sink_island_control as control input.
  // GraphOp, YieldOp and NextIterationSourceOp don't have control inputs so
  // exclude them below.
  // The loop only records entries in `new_control_inputs`; it does not mutate
  // the users, so the use lists being walked stay intact.
  for (Value out : island_op.outputs()) {
    for (auto& use : out.getUses()) {
      Operation* owner = use.getOwner();
      if (auto other_island_op =
              llvm::dyn_cast<tf_executor::IslandOp>(owner->getParentOp())) {
        // The user is nested inside another island: the control input goes on
        // that enclosing island.
        (*new_control_inputs)[other_island_op].push_back(sink_island_control);
      } else if (owner->getDialect() == island_op->getDialect() &&
                 !llvm::isa<tf_executor::GraphOp, tf_executor::YieldOp,
                            tf_executor::NextIterationSourceOp>(owner)) {
        (*new_control_inputs)[owner].push_back(sink_island_control);
      } else {
        owner->emitOpError("adding control dependency not supported");
        return signalPassFailure();
      }
    }
  }
  // Forward data users of the original island to the values it yielded. For
  // values produced inside the island, these are now results of the new
  // islands. The original island is then dropped.
  for (auto item :
       llvm::zip(island_op.outputs(), island_op.GetYield().fetches()))
    std::get<0>(item).replaceAllUsesWith(std::get<1>(item));
  island_op.erase();
}
358
359 } // namespace
360
CreateBreakUpIslandsPass()361 std::unique_ptr<OperationPass<ModuleOp>> CreateBreakUpIslandsPass() {
362 return std::make_unique<BreakUpIslands>();
363 }
364
365 } // namespace mlir
366
367 static mlir::PassRegistration<mlir::BreakUpIslands> pass;
368