1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This transformation pass takes TensorFlow executor dialect IslandOps and
// merges them. Note: this does not yet handle TensorFlow V1 style control
// flow/frames or side-effecting ops.
19
20 #include <iterator>
21 #include <tuple>
22
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/None.h"
25 #include "llvm/ADT/Optional.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/Support/Casting.h"
30 #include "mlir/IR/Block.h" // from @llvm-project
31 #include "mlir/IR/Builders.h" // from @llvm-project
32 #include "mlir/IR/Location.h" // from @llvm-project
33 #include "mlir/IR/Operation.h" // from @llvm-project
34 #include "mlir/Pass/Pass.h" // from @llvm-project
35 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
36 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
37 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
38 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
39 #include "tensorflow/core/platform/logging.h"
40
41 namespace mlir {
42 namespace tf_executor {
43
44 namespace {
45
46 //===----------------------------------------------------------------------===//
47 // Analysis
48 //===----------------------------------------------------------------------===//
49
50 // This structure represents a merged island. It includes all of the islands
51 // that can be merged together and the point of insertion of the merged island.
52 struct MergedIsland {
53 // Construct a new island from the given root.
MergedIslandmlir::tf_executor::__anonc689b1ec0111::MergedIsland54 explicit MergedIsland(IslandOp root) : insert_point(root) {
55 islands.push_back(root);
56 }
57
58 // The insertion point anchor of the merged island, or where the merged island
59 // will be inserted when created.
60 Operation* const insert_point;
61
62 // The set of islands that are to be merged together.
63 SmallVector<IslandOp> islands;
64 };
65
66 // This structure contains all of the merge decisions for islands within a
67 // graph. We compute which islands to merge first, so that we don't need to
68 // recursively mutate the IR (resulting in quadratic behavior when moving
69 // operations). A rough sketch of the coarsening algorithm is shown below:
70 //
71 // // The algorithm iterates until a fixpoint is reached, i.e. when no more
72 // // islands can be merged.
73 // while (changed) {
74 // // In the first phase we try to merge islands with their nearest consumer
75 // // iff the consumer is another island.
76 // // Note: A consumer is an operation that consumes one of our outputs.
77 // changed |= tryMergedIslandsIntoNearestConsumer();
78 //
79 // // In the second phase we try to merge islands with their nearest producer
80 // // of a value they consume, iff the producer is another island.
81 // // Note: A producer is an operation that produces one of our inputs.
82 // changed |= tryMergedIslandsIntoNearestProducer();
83 // }
84 //
85 class CoarseningAnalysis {
86 public:
87 // Compute the coarsening analysis over the given graph.
88 explicit CoarseningAnalysis(GraphOp graph);
89
90 // Returns a list of all of the mergable islands found in the graph.
91 iterator_range<
92 llvm::filter_iterator<SmallVector<MergedIsland>::const_iterator,
93 function_ref<bool(const MergedIsland&)>>>
GetMergableIslands() const94 GetMergableIslands() const {
95 function_ref<bool(const MergedIsland&)> filter_fn =
96 [](const MergedIsland& merged_island) {
97 return merged_island.islands.size() > 1;
98 };
99 return llvm::make_filter_range(merged_islands_, filter_fn);
100 }
101
102 private:
103 // Attempt to find an island group that produces a value consumed by one of
104 // the islands (or operation therein) within the given `merged_island`. If no
105 // candidate can be found, returns nullptr.
106 MergedIsland* GetOperandCandidateToMergeWith(GraphOp graph,
107 MergedIsland& merged_island);
108
109 // Attempt to find an island group that consumes a result, either control or
110 // data, from one of the islands in the given `merged_island`. If no candidate
111 // can be found, returns nullptr.
112 MergedIsland* GetResultCandidateToMergeWith(GraphOp graph,
113 MergedIsland& merged_island);
114
115 // All of the merged islands in the graph.
116 SmallVector<MergedIsland> merged_islands_;
117 // A mapping from an island operation to the current merged island group it
118 // is a part of.
119 DenseMap<Operation*, MergedIsland*> island_to_merged_island_;
120 };
121
CoarseningAnalysis(GraphOp graph)122 CoarseningAnalysis::CoarseningAnalysis(GraphOp graph) {
123 // As an initial step, construct a merged island for each island in the
124 // graph.
125 for (IslandOp island : graph.getBody()->getOps<IslandOp>())
126 merged_islands_.push_back(MergedIsland(island));
127
128 // Record the mapping from the island to the merge group as a secondary step,
129 // as we are taking the address of the islands here and the push_back step
130 // above may invalidate previously inserted islands mid-loop.
131 for (MergedIsland& island : merged_islands_)
132 island_to_merged_island_.try_emplace(island.insert_point, &island);
133
134 // This functor merges the given `old_merged_island` into the
135 // `new_merged_island`. `merge_in_front` is whether the old island should be
136 // merged into the front of the new island, or the back.
137 auto merge_islands = [&](MergedIsland& old_merged_island,
138 MergedIsland& new_merged_island,
139 bool merge_in_front) {
140 for (IslandOp island : old_merged_island.islands)
141 island_to_merged_island_[island] = &new_merged_island;
142
143 auto insert_point = merge_in_front ? new_merged_island.islands.begin()
144 : new_merged_island.islands.end();
145 new_merged_island.islands.insert(insert_point,
146 old_merged_island.islands.begin(),
147 old_merged_island.islands.end());
148 old_merged_island.islands.clear();
149 };
150
151 // Iterate over all of the island groups attempting to merge as many islands
152 // groups as possible.
153 bool updated = false;
154 do {
155 updated = false;
156
157 // Attempt to merge an island into an island consuming one of its results.
158 for (MergedIsland& merged_island : llvm::reverse(merged_islands_)) {
159 if (merged_island.islands.empty()) continue;
160
161 MergedIsland* candidate =
162 GetResultCandidateToMergeWith(graph, merged_island);
163 if (candidate) {
164 merge_islands(merged_island, *candidate, /*merge_in_front=*/true);
165 updated = true;
166 }
167 }
168
169 // Attempt to merge an island into an island producing one of its operands.
170 for (MergedIsland& merged_island : merged_islands_) {
171 if (merged_island.islands.empty()) continue;
172
173 MergedIsland* candidate =
174 GetOperandCandidateToMergeWith(graph, merged_island);
175 if (candidate) {
176 merge_islands(merged_island, *candidate, /*merge_in_front=*/false);
177 updated = true;
178 }
179 }
180 } while (updated);
181 }
182
GetOperandCandidateToMergeWith(GraphOp graph,MergedIsland & merged_island)183 MergedIsland* CoarseningAnalysis::GetOperandCandidateToMergeWith(
184 GraphOp graph, MergedIsland& merged_island) {
185 // The candidate operation to consider merging the current island group with.
186 Operation* candidate = nullptr;
187 // The island group of the current candidate if it is an IslandOp, nullptr
188 // otherwise.
189 MergedIsland* candidate_island = nullptr;
190
191 // Given an input operation, try to replace the current candidate operation
192 // with it.
193 auto try_update_current_candidate = [&](Operation* rhs) {
194 MergedIsland* rhs_island = nullptr;
195 // Check if this is an island operation we can merge with.
196 auto rhs_it = island_to_merged_island_.find(rhs);
197 if (rhs_it != island_to_merged_island_.end()) {
198 rhs_island = rhs_it->second;
199
200 // Ignore islands that are already a part of the current island group.
201 if (rhs_island == &merged_island) return;
202
203 rhs = rhs_island->insert_point;
204 }
205 if (!candidate || candidate->isBeforeInBlock(rhs)) {
206 candidate = rhs;
207 candidate_island = rhs_island;
208 }
209 };
210
211 // Check island control operands.
212 for (IslandOp island : merged_island.islands) {
213 for (Value input : island.controlInputs()) {
214 Operation* def = input.getDefiningOp();
215 DCHECK_EQ(def->getParentOp(), graph);
216 try_update_current_candidate(def);
217 }
218
219 // Check island data operands.
220 island.walk([&](Operation* op) {
221 for (Value input : op->getOperands()) {
222 Operation* def = input.getDefiningOp();
223 if (!def || def->getParentOp() != graph) continue;
224
225 try_update_current_candidate(def);
226 }
227 });
228 }
229
230 return candidate_island;
231 }
232
GetResultCandidateToMergeWith(GraphOp graph,MergedIsland & merged_island)233 MergedIsland* CoarseningAnalysis::GetResultCandidateToMergeWith(
234 GraphOp graph, MergedIsland& merged_island) {
235 // The candidate operation to consider merging the current island group with.
236 Operation* candidate = nullptr;
237 // The island group of the current candidate if it is an IslandOp, nullptr
238 // otherwise.
239 MergedIsland* candidate_island = nullptr;
240
241 // Given an input operation, try to replace the current candidate operation
242 // with it.
243 auto try_update_current_candidate = [&](Operation* rhs) {
244 MergedIsland* rhs_island = nullptr;
245
246 // Check if this is an island operation we can merge with.
247 auto rhs_it = island_to_merged_island_.find(rhs);
248 if (rhs_it != island_to_merged_island_.end()) {
249 rhs_island = rhs_it->second;
250
251 // Ignore islands that are already a part of the current island group.
252 if (rhs_island == &merged_island) return;
253
254 rhs = rhs_island->insert_point;
255 }
256 if (!candidate || rhs->isBeforeInBlock(candidate)) {
257 candidate = rhs;
258 candidate_island = rhs_island;
259 }
260 };
261
262 // Check island control results.
263 for (IslandOp island : merged_island.islands) {
264 for (Operation* user : island.control().getUsers()) {
265 DCHECK_EQ(user->getParentOp(), graph);
266 try_update_current_candidate(user);
267 }
268
269 // Check island data results.
270 Block& graph_body = llvm::cast<GraphOp>(graph).GetBody();
271 for (Value result : island.outputs()) {
272 for (Operation* user : result.getUsers()) {
273 Operation* def = graph_body.findAncestorOpInBlock(*user);
274 DCHECK_NE(def, nullptr);
275 try_update_current_candidate(def);
276 }
277 }
278 }
279
280 return candidate_island;
281 }
282
283 //===----------------------------------------------------------------------===//
284 // Transformation
285 //===----------------------------------------------------------------------===//
286
287 // IslandResult is a helper struct holding an islands result and associated
288 // inner op result.
289 struct IslandResult {
IslandResultmlir::tf_executor::__anonc689b1ec0111::IslandResult290 IslandResult(Value inner_op_result, Value island_result)
291 : inner_op_result(inner_op_result), island_result(island_result) {}
292
293 Value inner_op_result;
294 Value island_result;
295 };
296
297 // This structure is used to gather the new operands and result of an island
298 // during merging.
299 struct IslandOperandsAndResults {
300 llvm::SmallSetVector<Value, 8> operands;
301 llvm::SmallVector<IslandResult> results;
302 };
303
304 // Collects the results for the new island by going through each data result of
305 // the islands being merged. Unused results outside of the merged island to be
306 // formed are pruned. If the child island inner ops consume the parent island
307 // control result, the child island inner ops will have that respective control
308 // input pruned. Results of the parent island that are consumed by the child
309 // island are replaced by the respective inner ops result from the parent
310 // island.
GetNewIslandResultsAndForwardResults(const MergedIsland & merged_island,llvm::SmallVector<IslandResult> & results)311 void GetNewIslandResultsAndForwardResults(
312 const MergedIsland& merged_island,
313 llvm::SmallVector<IslandResult>& results) {
314 results.clear();
315
316 // Collect all of the blocks within each of the island operations, these will
317 // be used to detect when an operation has a use within one of the merged
318 // islands.
319 llvm::SmallPtrSet<Block*, 8> islandBlocks;
320 for (IslandOp island : merged_island.islands)
321 island->walk([&](Block* block) { islandBlocks.insert(block); });
322
323 for (IslandOp island : merged_island.islands) {
324 for (auto ret_vals :
325 llvm::zip(island.GetYield().getOperands(), island.outputs())) {
326 bool result_captured = false;
327 Value inner_op_result = std::get<0>(ret_vals);
328 Value island_result = std::get<1>(ret_vals);
329 for (auto& use : llvm::make_early_inc_range(island_result.getUses())) {
330 if (islandBlocks.count(use.getOwner()->getBlock())) {
331 // If the use is within our island group, forward the result from
332 // inner op.
333 use.set(inner_op_result);
334 } else if (!result_captured) {
335 results.emplace_back(inner_op_result, island_result);
336 result_captured = true;
337 }
338 }
339 }
340 }
341 }
342
343 // Creates the new merged island.
CreateNewIsland(const MergedIsland & merged_island,llvm::ArrayRef<Value> operands,llvm::ArrayRef<IslandResult> results)344 IslandOp CreateNewIsland(const MergedIsland& merged_island,
345 llvm::ArrayRef<Value> operands,
346 llvm::ArrayRef<IslandResult> results) {
347 // Collect types from results.
348 llvm::SmallVector<Type, 8> result_types;
349 result_types.reserve(results.size());
350 for (const auto& result : results)
351 result_types.push_back(result.inner_op_result.getType());
352
353 // IslandOps always have a control result.
354 result_types.push_back(
355 ControlType::get(merged_island.insert_point->getContext()));
356
357 OpBuilder builder(merged_island.insert_point);
358 auto new_island = builder.create<IslandOp>(
359 merged_island.insert_point->getLoc(), result_types, operands);
360 new_island.body().push_back(new Block);
361 return new_island;
362 }
363
364 // Creates respective YieldOp for the new merged island.
CreateNewIslandYieldOp(IslandOp new_island,llvm::ArrayRef<IslandResult> results)365 YieldOp CreateNewIslandYieldOp(IslandOp new_island,
366 llvm::ArrayRef<IslandResult> results) {
367 llvm::SmallVector<Value, 8> yield_operands;
368 yield_operands.reserve(results.size());
369
370 for (auto ret_vals : llvm::zip(results, new_island.outputs())) {
371 const auto& old_result = std::get<0>(ret_vals);
372
373 // Replace original island result with new island result.
374 old_result.island_result.replaceAllUsesWith(std::get<1>(ret_vals));
375
376 // Add associated inner op result to operands of the YieldOp.
377 yield_operands.push_back(old_result.inner_op_result);
378 }
379
380 // Create YieldOp for the new island.
381 OpBuilder builder(&new_island.GetBody(), new_island.GetBody().end());
382 return builder.create<YieldOp>(new_island.getLoc(), yield_operands);
383 }
384
385 // Moves inner ops (excluding last op/YieldOp) from islands being merged into
386 // the new merged island.
MoveInnerOpsToNewIsland(const MergedIsland & merged_island,Operation * new_yield_op)387 void MoveInnerOpsToNewIsland(const MergedIsland& merged_island,
388 Operation* new_yield_op) {
389 Block* block = new_yield_op->getBlock();
390
391 auto move_inner_ops = [block, new_yield_op](IslandOp island) {
392 auto& island_body = island.GetBody().getOperations();
393 block->getOperations().splice(new_yield_op->getIterator(), island_body,
394 island_body.begin(),
395 std::prev(island_body.end()));
396 };
397 for (IslandOp island : merged_island.islands) move_inner_ops(island);
398 }
399
400 // Merges the islands within the given island group.
401 // `island_operands_and_results` is passed in as scrach storage for the duration
402 // of this function.
MergeIslands(const MergedIsland & merged_island,IslandOperandsAndResults & island_operands_and_results)403 void MergeIslands(const MergedIsland& merged_island,
404 IslandOperandsAndResults& island_operands_and_results) {
405 // Collect operands for the new merged island.
406 island_operands_and_results.operands.clear();
407 for (IslandOp island : merged_island.islands)
408 island_operands_and_results.operands.insert(island.operand_begin(),
409 island.operand_end());
410 for (IslandOp island : merged_island.islands)
411 island_operands_and_results.operands.remove(island.control());
412
413 // Collect results for the new merged island.
414 GetNewIslandResultsAndForwardResults(merged_island,
415 island_operands_and_results.results);
416
417 // Create the new merged island.
418 IslandOp new_island = CreateNewIsland(
419 merged_island, island_operands_and_results.operands.getArrayRef(),
420 island_operands_and_results.results);
421
422 // Create associated YieldOp for the new merged island.
423 YieldOp new_yield_op =
424 CreateNewIslandYieldOp(new_island, island_operands_and_results.results);
425
426 // Move inner ops from original islands into the new island.
427 MoveInnerOpsToNewIsland(merged_island, new_yield_op.getOperation());
428
429 // Update control inputs to point to the new merged island.
430 for (IslandOp island : merged_island.islands)
431 island.control().replaceAllUsesWith(new_island.control());
432 for (IslandOp island : merged_island.islands) island->erase();
433 }
434
435 // Takes the inputs to tf_executor.fetch, make a new island that just yields
436 // them, and replace the fetch's input operands with the new yielded values.
437 //
438 // This allows our def-use based island coarsening algorithm to merge
439 // islands that independently feed into a fetch.
InsertDummyIslandForFetch(FetchOp fetch)440 void InsertDummyIslandForFetch(FetchOp fetch) {
441 llvm::SmallVector<Value, 4> data_fetches;
442 llvm::SmallVector<Type, 4> data_types;
443 llvm::SmallVector<Value, 4> control_fetches;
444 data_fetches.reserve(fetch.fetches().size());
445 data_types.reserve(data_fetches.capacity());
446 control_fetches.reserve(data_fetches.capacity());
447
448 for (auto value : fetch.fetches()) {
449 if (value.getType().isa<ControlType>()) {
450 control_fetches.push_back(value);
451 } else {
452 data_fetches.push_back(value);
453 data_types.push_back(value.getType());
454 }
455 }
456 auto island = OpBuilder(fetch).create<IslandOp>(
457 fetch.getLoc(), data_types,
458 /*control=*/ControlType::get(fetch.getContext()),
459 /*controlInputs=*/control_fetches);
460 island.body().push_back(new Block);
461 OpBuilder::atBlockEnd(&island.GetBody())
462 .create<YieldOp>(fetch.getLoc(), data_fetches);
463 const int fetch_control_idx = data_fetches.size();
464 for (int i = 0, e = fetch.getNumOperands(); i < e; i++) {
465 // The fetch could have multiple control operands (all at the end of its
466 // operand list). We replace them all with the island's single control
467 // operand.
468 if (i <= fetch_control_idx) {
469 fetch.setOperand(i, island.getResult(i));
470 } else {
471 fetch.getOperation()->eraseOperand(fetch.getNumOperands() - 1);
472 }
473 }
474 }
475
476 //===----------------------------------------------------------------------===//
477 // Pass Entry Point
478 //===----------------------------------------------------------------------===//
479
480 struct ExecutorIslandCoarseningPass
481 : public TF::ExecutorIslandCoarseningPassBase<
482 ExecutorIslandCoarseningPass> {
483 void runOnOperation() override;
484 };
485
runOnOperation()486 void ExecutorIslandCoarseningPass::runOnOperation() {
487 // Temporary datastructure to keep operands and results for each island.
488 // We define it here to grow and reuse the storage for the duration of the
489 // pass.
490 IslandOperandsAndResults island_operands_and_results;
491
492 getOperation().walk([&](GraphOp graph) {
493 InsertDummyIslandForFetch(graph.GetFetch());
494
495 // Compute an analysis that decides which islands should be merged together,
496 // and merge any island groups it finds.
497 CoarseningAnalysis analysis(graph);
498 for (const MergedIsland& island : analysis.GetMergableIslands())
499 MergeIslands(island, island_operands_and_results);
500 });
501 }
502
503 } // namespace
504
505 std::unique_ptr<OperationPass<func::FuncOp>>
CreateTFExecutorIslandCoarseningPass()506 CreateTFExecutorIslandCoarseningPass() {
507 return std::make_unique<ExecutorIslandCoarseningPass>();
508 }
509
510 } // namespace tf_executor
511 } // namespace mlir
512