1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <cstdint>
17
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/ADT/DenseSet.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
23 #include "mlir/IR/Attributes.h" // from @llvm-project
24 #include "mlir/IR/Builders.h" // from @llvm-project
25 #include "mlir/IR/Operation.h" // from @llvm-project
26 #include "mlir/IR/Value.h" // from @llvm-project
27 #include "mlir/Pass/Pass.h" // from @llvm-project
28 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
32
33 // This pass is used in preparation for Graph export.
34 // The GraphDef exporter expects each op to be in its own island.
35 // This pass puts the IR in that form.
36 //
37 // We do this as an IR->IR transform to keep the Graph exporter as simple as
38 // possible.
39
40 namespace mlir {
41
42 namespace {
43
44 class BreakUpIslands : public TF::PerFunctionAggregateAnalysisConsumerPass<
45 BreakUpIslands, TF::SideEffectAnalysis> {
getDependentDialects(DialectRegistry & registry) const46 void getDependentDialects(DialectRegistry& registry) const override {
47 registry.insert<tf_executor::TensorFlowExecutorDialect>();
48 }
49
50 public:
51 void runOnFunction(FuncOp func,
52 const TF::SideEffectAnalysis::Info& side_effect_analysis);
53
54 void BreakUpIsland(tf_executor::IslandOp island_op,
55 const TF::SideEffectAnalysis::Info& side_effect_analysis,
56 llvm::DenseMap<Operation*, llvm::SmallVector<Value, 4>>*
57 new_control_inputs);
58 };
59
// Pass entry point for one function.
//
// Requires the function's entry block to contain exactly one op besides its
// terminator, and that op must be a tf_executor.graph; otherwise an error is
// emitted and the pass fails. Then, in two phases:
//   1. Every island in the graph body is broken up via BreakUpIsland, which
//      records in `new_control_inputs` the control values that other ops of
//      the graph must additionally depend on.
//   2. Every op of the graph body with pending control inputs is recreated
//      (same name, result types, regions, attributes) with its original
//      operands followed by the new, de-duplicated control inputs. The
//      original op is then replaced and erased.
void BreakUpIslands::runOnFunction(
    FuncOp func, const TF::SideEffectAnalysis::Info& side_effect_analysis) {
  auto graph_op_range = func.front().without_terminator();
  tf_executor::GraphOp graph_op;

  if (llvm::hasSingleElement(graph_op_range))
    graph_op = dyn_cast<tf_executor::GraphOp>(func.front().front());

  if (!graph_op) {
    func.emitError("expected function to contain only a graph_op");
    signalPassFailure();
    return;
  }

  // New control inputs to be added. For an operation x, new_control_inputs[x]
  // contains all control inputs that need to be added to x as operands.
  llvm::DenseMap<Operation*, llvm::SmallVector<Value, 4>> new_control_inputs;
  // Iterate in reverse order to avoid invalidating Operation* stored in
  // new_control_inputs. (Keys are users of an island's outputs, which come
  // after that island, so they have already been processed when the island is
  // broken up.) make_early_inc_range lets BreakUpIsland erase `item`.
  for (auto& item :
       llvm::make_early_inc_range(llvm::reverse(graph_op.GetBody()))) {
    if (auto island = dyn_cast<tf_executor::IslandOp>(&item)) {
      BreakUpIsland(island, side_effect_analysis, &new_control_inputs);
    }
  }
  OpBuilder builder(func);

  // For every op, add new control inputs in reverse order so that the ops don't
  // get invalidated. Operands cannot be appended in place here, so each
  // affected op is rebuilt from an OperationState.
  // The scratch containers are hoisted out of the loop and reused per op.
  llvm::SmallVector<Value, 8> operands;
  llvm::SmallPtrSet<Operation*, 4> defining_ops;
  llvm::SmallVector<Type, 4> types;
  for (auto& item :
       llvm::make_early_inc_range(llvm::reverse(graph_op.GetBody()))) {
    auto it = new_control_inputs.find(&item);
    if (it == new_control_inputs.end()) continue;
    auto& new_control_inputs_for_item = it->second;
    // The replacement op is created right before the op it replaces.
    builder.setInsertionPoint(&item);
    OperationState state(item.getLoc(), item.getName());
    types.assign(item.result_type_begin(), item.result_type_end());
    state.addTypes(types);
    // Move (not copy) each region's body into the replacement op.
    for (Region& region : item.getRegions()) {
      state.addRegion()->takeBody(region);
    }
    // Assign existing operands for item.
    operands.assign(item.operand_begin(), item.operand_end());

    // Collect defining ops for existing operands.
    defining_ops.clear();
    for (Value operand : operands) {
      defining_ops.insert(operand.getDefiningOp());
    }
    // Entries were appended while islands were visited in reverse order in
    // the first loop; iterating them in reverse restores program order.
    for (Value new_control_input : llvm::reverse(new_control_inputs_for_item)) {
      // Add new control input if its defining op is not already a defining
      // op for some other operand (the op already depends on that producer).
      // Update defining_ops so duplicates are also dropped.
      if (defining_ops.insert(new_control_input.getDefiningOp()).second) {
        operands.push_back(new_control_input);
      }
    }
    state.addOperands(operands);
    Operation* new_op = builder.createOperation(state);
    item.replaceAllUsesWith(new_op);
    new_op->setAttrs(item.getAttrDictionary());
    item.erase();
  }
}
126
127 // Populates an empty IslandOp and with a NoOp or Identity/IdentityN depending
128 // on if there are any data results.
PopulateEmptyIsland(tf_executor::IslandOp island)129 void PopulateEmptyIsland(tf_executor::IslandOp island) {
130 OpBuilder builder(&island.GetBody(), island.GetBody().begin());
131 tf_executor::YieldOp yield = island.GetYield();
132 if (yield.getNumOperands() == 0) {
133 builder.create<TF::NoOp>(island.getLoc(), TypeRange{}, ValueRange{});
134 } else if (yield.getNumOperands() == 1) {
135 Value operand = yield.getOperand(0);
136 auto identity = builder.create<TF::IdentityOp>(island.getLoc(),
137 operand.getType(), operand);
138 yield.setOperand(0, identity.output());
139 } else {
140 auto identity_n = builder.create<TF::IdentityNOp>(
141 island.getLoc(), yield.getOperandTypes(), yield.getOperands());
142 for (auto it : llvm::enumerate(identity_n.getResults()))
143 yield.setOperand(it.index(), it.value());
144 }
145 }
146
147 // Helper that creates an island. If `sub_op` is not nullptr, it will be moved
148 // to the island. Otherwise a NoOp will be added to the island.
CreateIsland(TypeRange result_types,ValueRange control_inputs,const tf_executor::ControlType & control_type,const Location & loc,Operation * sub_op,tf_executor::IslandOp original_island)149 tf_executor::IslandOp CreateIsland(TypeRange result_types,
150 ValueRange control_inputs,
151 const tf_executor::ControlType& control_type,
152 const Location& loc, Operation* sub_op,
153 tf_executor::IslandOp original_island) {
154 OpBuilder builder(original_island);
155 auto island = builder.create<tf_executor::IslandOp>(
156 loc, result_types, control_type, control_inputs);
157 island.body().push_back(new Block);
158 Block* block = &island.body().back();
159 OpBuilder island_builder(original_island);
160 island_builder.setInsertionPointToEnd(block);
161 if (sub_op) {
162 sub_op->replaceAllUsesWith(island.outputs());
163 sub_op->moveBefore(block, block->begin());
164 island_builder.create<tf_executor::YieldOp>(loc, sub_op->getResults());
165 } else {
166 island_builder.create<TF::NoOp>(island.getLoc(), TypeRange{}, ValueRange{});
167 island_builder.create<tf_executor::YieldOp>(loc, ValueRange{});
168 }
169 return island;
170 }
171
172 // A struct contains the operations in an island that do not have incoming or
173 // outgoing dependencies.
174 struct IslandSourcesAndSinks {
175 // Sub-ops that do not depend on other sub-ops in the island.
176 llvm::SmallPtrSet<Operation*, 4> sources;
177 // Sub-ops that do not have other sub-ops in the island depending on them
178 // (excluding yield).
179 llvm::SmallPtrSet<Operation*, 4> sinks;
180 };
181
182 // Finds IslandSourcesAndSinks for an unmodified island.
FindSourcesAndSinksInIsland(tf_executor::IslandOp island,const TF::SideEffectAnalysis::Info & side_effect_analysis)183 IslandSourcesAndSinks FindSourcesAndSinksInIsland(
184 tf_executor::IslandOp island,
185 const TF::SideEffectAnalysis::Info& side_effect_analysis) {
186 IslandSourcesAndSinks result;
187 auto island_body = island.GetBody().without_terminator();
188 for (Operation& sub_op : island_body) {
189 auto predecessors = side_effect_analysis.DirectControlPredecessors(&sub_op);
190 result.sinks.insert(&sub_op);
191 // Remove predecessor from sinks.
192 for (auto predecessor : predecessors) result.sinks.erase(predecessor);
193 bool has_in_island_operands = false;
194 for (auto operand : sub_op.getOperands()) {
195 auto defining_op = operand.getDefiningOp();
196 if (!defining_op || defining_op->getParentOp() != island) continue;
197 // Remove operands from sinks.
198 result.sinks.erase(defining_op);
199 has_in_island_operands = true;
200 }
201 if (predecessors.empty() && !has_in_island_operands) {
202 result.sources.insert(&sub_op);
203 }
204 }
205 return result;
206 }
207
// Converts a single island into multiple islands (one for each op). The islands
// are chained together by control flow values.
//
// Behavior:
//  - An island whose body is only its yield is populated in place (see
//    PopulateEmptyIsland). An island that already wraps a single op is left
//    untouched. Neither is erased.
//  - Otherwise each op of the island is moved into its own new island,
//    inserted before `island_op`. Ops with no in-island dependency ("sources")
//    receive the original island's control inputs, plus the controls of other
//    islands whose values this island yields. Multiple such controls are first
//    grouped through one extra NoOp island. Every other op receives the
//    controls of the new islands wrapping its direct side-effect control
//    predecessors.
//  - Uses of the original island's control output are redirected to a single
//    control value covering all "sink" islands; several sinks are grouped
//    through one extra NoOp island.
//  - That same control is recorded in `new_control_inputs` for every user of
//    the original island's data outputs. It is keyed by the enclosing island
//    when the user sits inside one, otherwise by the user itself. The caller
//    appends these later. The users are then rewired to the yielded values.
//  - Finally `island_op` is erased.
// On unsupported input, an error is emitted and the pass is marked as failed.
void BreakUpIslands::BreakUpIsland(
    tf_executor::IslandOp island_op,
    const TF::SideEffectAnalysis::Info& side_effect_analysis,
    llvm::DenseMap<Operation*, llvm::SmallVector<Value, 4>>*
        new_control_inputs) {
  auto island_body = island_op.GetBody().without_terminator();
  // Populate islands that are empty (only yield).
  if (island_body.empty()) {
    PopulateEmptyIsland(island_op);
    return;
  }

  // Skip islands that are already only a single op.
  if (island_op.WrapsSingleOp()) return;

  auto control_type = tf_executor::ControlType::get(&getContext());
  auto island_control_inputs = llvm::to_vector<4>(island_op.controlInputs());
  // Add control dependencies for yields of values defined by other islands to
  // the island that defines that fetched value.
  for (auto fetch : island_op.GetYield().fetches()) {
    if (!fetch.getDefiningOp()) {
      // Skip, because there is no op to add control to (eg: function args).
      continue;
    } else if (fetch.getDefiningOp()->getParentOp() == island_op) {
      // Skip, because it is the same island.
      continue;
    } else if (auto other_island_op = llvm::dyn_cast<tf_executor::IslandOp>(
                   fetch.getDefiningOp())) {
      island_control_inputs.push_back(other_island_op.control());
    } else {
      // TODO(parkers): Any defining op that has a control output can be handled
      // just like an island.
      fetch.getDefiningOp()->emitError("fetching non-island as dependency");
      return signalPassFailure();
    }
  }
  // If there are multiple control inputs, create an empty island to group them.
  // (CreateIsland with a null sub_op gives it a single tf.NoOp.)
  if (island_control_inputs.size() > 1) {
    auto new_island = CreateIsland({}, island_control_inputs, control_type,
                                   island_op.getLoc(), nullptr, island_op);
    island_control_inputs.clear();
    island_control_inputs.push_back(new_island.control());
  }
  // Find sources and sinks inside the original island. This must happen before
  // any sub-op is moved out of it.
  auto sources_and_sinks =
      FindSourcesAndSinksInIsland(island_op, side_effect_analysis);
  // The corresponding control output of the new island created for each sub-op.
  llvm::SmallDenseMap<Operation*, Value, 8> new_control_for_sub_ops;
  // Control outputs of newly created islands that are sinks.
  llvm::SmallVector<Value, 8> sink_island_controls;
  // For each operation in the island, construct a new island to wrap the op,
  // yield all the results, and replace all the usages with the results of the
  // new island. make_early_inc_range is needed because CreateIsland moves
  // `sub_op` out of this block.
  for (auto& sub_op : llvm::make_early_inc_range(island_body)) {
    const auto predecessors =
        side_effect_analysis.DirectControlPredecessors(&sub_op);
    // Get the controls from the predecessors.
    // NOTE(review): this assumes every predecessor precedes `sub_op` in the
    // island and was therefore already wrapped; otherwise operator[] would
    // insert and return a null Value.
    llvm::SmallVector<Value, 4> predecessor_controls;
    predecessor_controls.reserve(predecessors.size());
    for (auto predecessor : predecessors) {
      predecessor_controls.push_back(new_control_for_sub_ops[predecessor]);
    }
    // If sub_op is a source, use island_control_inputs, because that's required
    // by inter-islands dependencies; otherwise, we do not need to include
    // island_control_inputs, since they must have been tracked by the (direct
    // or indirect) control predecessors or operands.
    ArrayRef<Value> control = sources_and_sinks.sources.count(&sub_op) > 0
                                  ? island_control_inputs
                                  : predecessor_controls;
    auto new_island =
        CreateIsland(sub_op.getResultTypes(), control, control_type,
                     sub_op.getLoc(), &sub_op, island_op);
    new_control_for_sub_ops[&sub_op] = new_island.control();
    if (sources_and_sinks.sinks.count(&sub_op)) {
      sink_island_controls.push_back(new_island.control());
    }
  }
  // Create control outputs for the sinks.
  // The island body is non-empty here, and its last op cannot have an
  // in-island dependent, so at least one sink exists.
  assert(!sink_island_controls.empty());
  // If there are multiple control outputs, create an empty island to group
  // them.
  if (sink_island_controls.size() > 1) {
    auto new_island = CreateIsland({}, sink_island_controls, control_type,
                                   island_op.getLoc(), nullptr, island_op);
    sink_island_controls.clear();
    sink_island_controls.push_back(new_island.control());
  }
  assert(sink_island_controls.size() == 1);
  auto& sink_island_control = sink_island_controls[0];
  island_op.control().replaceAllUsesWith(sink_island_control);
  // All existing outputs need to add sink_island_control as control input.
  // GraphOp, YieldOp and NextIterationSourceOp don't have control inputs so
  // exclude them below. A use by any such op (or by an op outside the
  // tf_executor dialect that is not nested in an island) is reported as an
  // error.
  for (Value out : island_op.outputs()) {
    for (auto& use : out.getUses()) {
      Operation* owner = use.getOwner();
      if (auto other_island_op =
              llvm::dyn_cast<tf_executor::IslandOp>(owner->getParentOp())) {
        // The user is nested in an island: the island takes the control.
        (*new_control_inputs)[other_island_op].push_back(sink_island_control);
      } else if (owner->getDialect() == island_op->getDialect() &&
                 !llvm::isa<tf_executor::GraphOp, tf_executor::YieldOp,
                            tf_executor::NextIterationSourceOp>(owner)) {
        (*new_control_inputs)[owner].push_back(sink_island_control);
      } else {
        owner->emitOpError("adding control dependency not supported");
        return signalPassFailure();
      }
    }
  }
  // Rewire data users of the original island to the values it yielded (these
  // are now produced by the new islands or by other islands), then drop it.
  for (auto item :
       llvm::zip(island_op.outputs(), island_op.GetYield().fetches()))
    std::get<0>(item).replaceAllUsesWith(std::get<1>(item));
  island_op.erase();
}
324
325 } // namespace
326
CreateBreakUpIslandsPass()327 std::unique_ptr<OperationPass<ModuleOp>> CreateBreakUpIslandsPass() {
328 return std::make_unique<BreakUpIslands>();
329 }
330
331 } // namespace mlir
332
333 static mlir::PassRegistration<mlir::BreakUpIslands> pass(
334 "tf-executor-break-up-islands",
335 "Transform from TF control dialect to TF executor dialect.");
336