• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This transformation pass takes TensorFlow executor dialect IslandOps and
// merges them. Note: this does not yet handle TensorFlow V1 style control
// flow/frames or side-effecting ops.
19 
20 #include <iterator>
21 #include <tuple>
22 
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/None.h"
25 #include "llvm/ADT/Optional.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/Support/Casting.h"
30 #include "mlir/IR/Block.h"  // from @llvm-project
31 #include "mlir/IR/Builders.h"  // from @llvm-project
32 #include "mlir/IR/Location.h"  // from @llvm-project
33 #include "mlir/IR/Operation.h"  // from @llvm-project
34 #include "mlir/Pass/Pass.h"  // from @llvm-project
35 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
36 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
37 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
38 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
39 #include "tensorflow/core/platform/logging.h"
40 
41 namespace mlir {
42 namespace tf_executor {
43 
44 namespace {
45 
46 //===----------------------------------------------------------------------===//
47 // Analysis
48 //===----------------------------------------------------------------------===//
49 
50 // This structure represents a merged island. It includes all of the islands
51 // that can be merged together and the point of insertion of the merged island.
52 struct MergedIsland {
53   // Construct a new island from the given root.
MergedIslandmlir::tf_executor::__anonc689b1ec0111::MergedIsland54   explicit MergedIsland(IslandOp root) : insert_point(root) {
55     islands.push_back(root);
56   }
57 
58   // The insertion point anchor of the merged island, or where the merged island
59   // will be inserted when created.
60   Operation* const insert_point;
61 
62   // The set of islands that are to be merged together.
63   SmallVector<IslandOp> islands;
64 };
65 
66 // This structure contains all of the merge decisions for islands within a
67 // graph. We compute which islands to merge first, so that we don't need to
68 // recursively mutate the IR (resulting in quadratic behavior when moving
69 // operations). A rough sketch of the coarsening algorithm is shown below:
70 //
71 // // The algorithm iterates until a fixpoint is reached, i.e. when no more
72 // // islands can be merged.
73 // while (changed) {
74 //   // In the first phase we try to merge islands with their nearest consumer
75 //   // iff the consumer is another island.
76 //   // Note: A consumer is an operation that consumes one of our outputs.
77 //   changed |= tryMergedIslandsIntoNearestConsumer();
78 //
79 //   // In the second phase we try to merge islands with their nearest producer
80 //   // of a value they consume, iff the producer is another island.
81 //   // Note: A producer is an operation that produces one of our inputs.
82 //   changed |= tryMergedIslandsIntoNearestProducer();
83 // }
84 //
85 class CoarseningAnalysis {
86  public:
87   // Compute the coarsening analysis over the given graph.
88   explicit CoarseningAnalysis(GraphOp graph);
89 
90   // Returns a list of all of the mergable islands found in the graph.
91   iterator_range<
92       llvm::filter_iterator<SmallVector<MergedIsland>::const_iterator,
93                             function_ref<bool(const MergedIsland&)>>>
GetMergableIslands() const94   GetMergableIslands() const {
95     function_ref<bool(const MergedIsland&)> filter_fn =
96         [](const MergedIsland& merged_island) {
97           return merged_island.islands.size() > 1;
98         };
99     return llvm::make_filter_range(merged_islands_, filter_fn);
100   }
101 
102  private:
103   // Attempt to find an island group that produces a value consumed by one of
104   // the islands (or operation therein) within the given `merged_island`. If no
105   // candidate can be found, returns nullptr.
106   MergedIsland* GetOperandCandidateToMergeWith(GraphOp graph,
107                                                MergedIsland& merged_island);
108 
109   // Attempt to find an island group that consumes a result, either control or
110   // data, from one of the islands in the given `merged_island`. If no candidate
111   // can be found, returns nullptr.
112   MergedIsland* GetResultCandidateToMergeWith(GraphOp graph,
113                                               MergedIsland& merged_island);
114 
115   // All of the merged islands in the graph.
116   SmallVector<MergedIsland> merged_islands_;
117   // A mapping from an island operation to the current merged island group it
118   // is a part of.
119   DenseMap<Operation*, MergedIsland*> island_to_merged_island_;
120 };
121 
CoarseningAnalysis(GraphOp graph)122 CoarseningAnalysis::CoarseningAnalysis(GraphOp graph) {
123   // As an initial step, construct a merged island for each island in the
124   // graph.
125   for (IslandOp island : graph.getBody()->getOps<IslandOp>())
126     merged_islands_.push_back(MergedIsland(island));
127 
128   // Record the mapping from the island to the merge group as a secondary step,
129   // as we are taking the address of the islands here and the push_back step
130   // above may invalidate previously inserted islands mid-loop.
131   for (MergedIsland& island : merged_islands_)
132     island_to_merged_island_.try_emplace(island.insert_point, &island);
133 
134   // This functor merges the given `old_merged_island` into the
135   // `new_merged_island`. `merge_in_front` is whether the old island should be
136   // merged into the front of the new island, or the back.
137   auto merge_islands = [&](MergedIsland& old_merged_island,
138                            MergedIsland& new_merged_island,
139                            bool merge_in_front) {
140     for (IslandOp island : old_merged_island.islands)
141       island_to_merged_island_[island] = &new_merged_island;
142 
143     auto insert_point = merge_in_front ? new_merged_island.islands.begin()
144                                        : new_merged_island.islands.end();
145     new_merged_island.islands.insert(insert_point,
146                                      old_merged_island.islands.begin(),
147                                      old_merged_island.islands.end());
148     old_merged_island.islands.clear();
149   };
150 
151   // Iterate over all of the island groups attempting to merge as many islands
152   // groups as possible.
153   bool updated = false;
154   do {
155     updated = false;
156 
157     // Attempt to merge an island into an island consuming one of its results.
158     for (MergedIsland& merged_island : llvm::reverse(merged_islands_)) {
159       if (merged_island.islands.empty()) continue;
160 
161       MergedIsland* candidate =
162           GetResultCandidateToMergeWith(graph, merged_island);
163       if (candidate) {
164         merge_islands(merged_island, *candidate, /*merge_in_front=*/true);
165         updated = true;
166       }
167     }
168 
169     // Attempt to merge an island into an island producing one of its operands.
170     for (MergedIsland& merged_island : merged_islands_) {
171       if (merged_island.islands.empty()) continue;
172 
173       MergedIsland* candidate =
174           GetOperandCandidateToMergeWith(graph, merged_island);
175       if (candidate) {
176         merge_islands(merged_island, *candidate, /*merge_in_front=*/false);
177         updated = true;
178       }
179     }
180   } while (updated);
181 }
182 
GetOperandCandidateToMergeWith(GraphOp graph,MergedIsland & merged_island)183 MergedIsland* CoarseningAnalysis::GetOperandCandidateToMergeWith(
184     GraphOp graph, MergedIsland& merged_island) {
185   // The candidate operation to consider merging the current island group with.
186   Operation* candidate = nullptr;
187   // The island group of the current candidate if it is an IslandOp, nullptr
188   // otherwise.
189   MergedIsland* candidate_island = nullptr;
190 
191   // Given an input operation, try to replace the current candidate operation
192   // with it.
193   auto try_update_current_candidate = [&](Operation* rhs) {
194     MergedIsland* rhs_island = nullptr;
195     // Check if this is an island operation we can merge with.
196     auto rhs_it = island_to_merged_island_.find(rhs);
197     if (rhs_it != island_to_merged_island_.end()) {
198       rhs_island = rhs_it->second;
199 
200       // Ignore islands that are already a part of the current island group.
201       if (rhs_island == &merged_island) return;
202 
203       rhs = rhs_island->insert_point;
204     }
205     if (!candidate || candidate->isBeforeInBlock(rhs)) {
206       candidate = rhs;
207       candidate_island = rhs_island;
208     }
209   };
210 
211   // Check island control operands.
212   for (IslandOp island : merged_island.islands) {
213     for (Value input : island.controlInputs()) {
214       Operation* def = input.getDefiningOp();
215       DCHECK_EQ(def->getParentOp(), graph);
216       try_update_current_candidate(def);
217     }
218 
219     // Check island data operands.
220     island.walk([&](Operation* op) {
221       for (Value input : op->getOperands()) {
222         Operation* def = input.getDefiningOp();
223         if (!def || def->getParentOp() != graph) continue;
224 
225         try_update_current_candidate(def);
226       }
227     });
228   }
229 
230   return candidate_island;
231 }
232 
GetResultCandidateToMergeWith(GraphOp graph,MergedIsland & merged_island)233 MergedIsland* CoarseningAnalysis::GetResultCandidateToMergeWith(
234     GraphOp graph, MergedIsland& merged_island) {
235   // The candidate operation to consider merging the current island group with.
236   Operation* candidate = nullptr;
237   // The island group of the current candidate if it is an IslandOp, nullptr
238   // otherwise.
239   MergedIsland* candidate_island = nullptr;
240 
241   // Given an input operation, try to replace the current candidate operation
242   // with it.
243   auto try_update_current_candidate = [&](Operation* rhs) {
244     MergedIsland* rhs_island = nullptr;
245 
246     // Check if this is an island operation we can merge with.
247     auto rhs_it = island_to_merged_island_.find(rhs);
248     if (rhs_it != island_to_merged_island_.end()) {
249       rhs_island = rhs_it->second;
250 
251       // Ignore islands that are already a part of the current island group.
252       if (rhs_island == &merged_island) return;
253 
254       rhs = rhs_island->insert_point;
255     }
256     if (!candidate || rhs->isBeforeInBlock(candidate)) {
257       candidate = rhs;
258       candidate_island = rhs_island;
259     }
260   };
261 
262   // Check island control results.
263   for (IslandOp island : merged_island.islands) {
264     for (Operation* user : island.control().getUsers()) {
265       DCHECK_EQ(user->getParentOp(), graph);
266       try_update_current_candidate(user);
267     }
268 
269     // Check island data results.
270     Block& graph_body = llvm::cast<GraphOp>(graph).GetBody();
271     for (Value result : island.outputs()) {
272       for (Operation* user : result.getUsers()) {
273         Operation* def = graph_body.findAncestorOpInBlock(*user);
274         DCHECK_NE(def, nullptr);
275         try_update_current_candidate(def);
276       }
277     }
278   }
279 
280   return candidate_island;
281 }
282 
283 //===----------------------------------------------------------------------===//
284 // Transformation
285 //===----------------------------------------------------------------------===//
286 
287 // IslandResult is a helper struct holding an islands result and associated
288 // inner op result.
289 struct IslandResult {
IslandResultmlir::tf_executor::__anonc689b1ec0111::IslandResult290   IslandResult(Value inner_op_result, Value island_result)
291       : inner_op_result(inner_op_result), island_result(island_result) {}
292 
293   Value inner_op_result;
294   Value island_result;
295 };
296 
297 // This structure is used to gather the new operands and result of an island
298 // during merging.
299 struct IslandOperandsAndResults {
300   llvm::SmallSetVector<Value, 8> operands;
301   llvm::SmallVector<IslandResult> results;
302 };
303 
304 // Collects the results for the new island by going through each data result of
305 // the islands being merged. Unused results outside of the merged island to be
306 // formed are pruned. If the child island inner ops consume the parent island
307 // control result, the child island inner ops will have that respective control
308 // input pruned. Results of the parent island that are consumed by the child
309 // island are replaced by the respective inner ops result from the parent
310 // island.
GetNewIslandResultsAndForwardResults(const MergedIsland & merged_island,llvm::SmallVector<IslandResult> & results)311 void GetNewIslandResultsAndForwardResults(
312     const MergedIsland& merged_island,
313     llvm::SmallVector<IslandResult>& results) {
314   results.clear();
315 
316   // Collect all of the blocks within each of the island operations, these will
317   // be used to detect when an operation has a use within one of the merged
318   // islands.
319   llvm::SmallPtrSet<Block*, 8> islandBlocks;
320   for (IslandOp island : merged_island.islands)
321     island->walk([&](Block* block) { islandBlocks.insert(block); });
322 
323   for (IslandOp island : merged_island.islands) {
324     for (auto ret_vals :
325          llvm::zip(island.GetYield().getOperands(), island.outputs())) {
326       bool result_captured = false;
327       Value inner_op_result = std::get<0>(ret_vals);
328       Value island_result = std::get<1>(ret_vals);
329       for (auto& use : llvm::make_early_inc_range(island_result.getUses())) {
330         if (islandBlocks.count(use.getOwner()->getBlock())) {
331           // If the use is within our island group, forward the result from
332           // inner op.
333           use.set(inner_op_result);
334         } else if (!result_captured) {
335           results.emplace_back(inner_op_result, island_result);
336           result_captured = true;
337         }
338       }
339     }
340   }
341 }
342 
343 // Creates the new merged island.
CreateNewIsland(const MergedIsland & merged_island,llvm::ArrayRef<Value> operands,llvm::ArrayRef<IslandResult> results)344 IslandOp CreateNewIsland(const MergedIsland& merged_island,
345                          llvm::ArrayRef<Value> operands,
346                          llvm::ArrayRef<IslandResult> results) {
347   // Collect types from results.
348   llvm::SmallVector<Type, 8> result_types;
349   result_types.reserve(results.size());
350   for (const auto& result : results)
351     result_types.push_back(result.inner_op_result.getType());
352 
353   // IslandOps always have a control result.
354   result_types.push_back(
355       ControlType::get(merged_island.insert_point->getContext()));
356 
357   OpBuilder builder(merged_island.insert_point);
358   auto new_island = builder.create<IslandOp>(
359       merged_island.insert_point->getLoc(), result_types, operands);
360   new_island.body().push_back(new Block);
361   return new_island;
362 }
363 
364 // Creates respective YieldOp for the new merged island.
CreateNewIslandYieldOp(IslandOp new_island,llvm::ArrayRef<IslandResult> results)365 YieldOp CreateNewIslandYieldOp(IslandOp new_island,
366                                llvm::ArrayRef<IslandResult> results) {
367   llvm::SmallVector<Value, 8> yield_operands;
368   yield_operands.reserve(results.size());
369 
370   for (auto ret_vals : llvm::zip(results, new_island.outputs())) {
371     const auto& old_result = std::get<0>(ret_vals);
372 
373     // Replace original island result with new island result.
374     old_result.island_result.replaceAllUsesWith(std::get<1>(ret_vals));
375 
376     // Add associated inner op result to operands of the YieldOp.
377     yield_operands.push_back(old_result.inner_op_result);
378   }
379 
380   // Create YieldOp for the new island.
381   OpBuilder builder(&new_island.GetBody(), new_island.GetBody().end());
382   return builder.create<YieldOp>(new_island.getLoc(), yield_operands);
383 }
384 
385 // Moves inner ops (excluding last op/YieldOp) from islands being merged into
386 // the new merged island.
MoveInnerOpsToNewIsland(const MergedIsland & merged_island,Operation * new_yield_op)387 void MoveInnerOpsToNewIsland(const MergedIsland& merged_island,
388                              Operation* new_yield_op) {
389   Block* block = new_yield_op->getBlock();
390 
391   auto move_inner_ops = [block, new_yield_op](IslandOp island) {
392     auto& island_body = island.GetBody().getOperations();
393     block->getOperations().splice(new_yield_op->getIterator(), island_body,
394                                   island_body.begin(),
395                                   std::prev(island_body.end()));
396   };
397   for (IslandOp island : merged_island.islands) move_inner_ops(island);
398 }
399 
400 // Merges the islands within the given island group.
401 // `island_operands_and_results` is passed in as scrach storage for the duration
402 // of this function.
MergeIslands(const MergedIsland & merged_island,IslandOperandsAndResults & island_operands_and_results)403 void MergeIslands(const MergedIsland& merged_island,
404                   IslandOperandsAndResults& island_operands_and_results) {
405   // Collect operands for the new merged island.
406   island_operands_and_results.operands.clear();
407   for (IslandOp island : merged_island.islands)
408     island_operands_and_results.operands.insert(island.operand_begin(),
409                                                 island.operand_end());
410   for (IslandOp island : merged_island.islands)
411     island_operands_and_results.operands.remove(island.control());
412 
413   // Collect results for the new merged island.
414   GetNewIslandResultsAndForwardResults(merged_island,
415                                        island_operands_and_results.results);
416 
417   // Create the new merged island.
418   IslandOp new_island = CreateNewIsland(
419       merged_island, island_operands_and_results.operands.getArrayRef(),
420       island_operands_and_results.results);
421 
422   // Create associated YieldOp for the new merged island.
423   YieldOp new_yield_op =
424       CreateNewIslandYieldOp(new_island, island_operands_and_results.results);
425 
426   // Move inner ops from original islands into the new island.
427   MoveInnerOpsToNewIsland(merged_island, new_yield_op.getOperation());
428 
429   // Update control inputs to point to the new merged island.
430   for (IslandOp island : merged_island.islands)
431     island.control().replaceAllUsesWith(new_island.control());
432   for (IslandOp island : merged_island.islands) island->erase();
433 }
434 
435 // Takes the inputs to tf_executor.fetch, make a new island that just yields
436 // them, and replace the fetch's input operands with the new yielded values.
437 //
438 // This allows our def-use based island coarsening algorithm to merge
439 // islands that independently feed into a fetch.
InsertDummyIslandForFetch(FetchOp fetch)440 void InsertDummyIslandForFetch(FetchOp fetch) {
441   llvm::SmallVector<Value, 4> data_fetches;
442   llvm::SmallVector<Type, 4> data_types;
443   llvm::SmallVector<Value, 4> control_fetches;
444   data_fetches.reserve(fetch.fetches().size());
445   data_types.reserve(data_fetches.capacity());
446   control_fetches.reserve(data_fetches.capacity());
447 
448   for (auto value : fetch.fetches()) {
449     if (value.getType().isa<ControlType>()) {
450       control_fetches.push_back(value);
451     } else {
452       data_fetches.push_back(value);
453       data_types.push_back(value.getType());
454     }
455   }
456   auto island = OpBuilder(fetch).create<IslandOp>(
457       fetch.getLoc(), data_types,
458       /*control=*/ControlType::get(fetch.getContext()),
459       /*controlInputs=*/control_fetches);
460   island.body().push_back(new Block);
461   OpBuilder::atBlockEnd(&island.GetBody())
462       .create<YieldOp>(fetch.getLoc(), data_fetches);
463   const int fetch_control_idx = data_fetches.size();
464   for (int i = 0, e = fetch.getNumOperands(); i < e; i++) {
465     // The fetch could have multiple control operands (all at the end of its
466     // operand list). We replace them all with the island's single control
467     // operand.
468     if (i <= fetch_control_idx) {
469       fetch.setOperand(i, island.getResult(i));
470     } else {
471       fetch.getOperation()->eraseOperand(fetch.getNumOperands() - 1);
472     }
473   }
474 }
475 
476 //===----------------------------------------------------------------------===//
477 // Pass Entry Point
478 //===----------------------------------------------------------------------===//
479 
480 struct ExecutorIslandCoarseningPass
481     : public TF::ExecutorIslandCoarseningPassBase<
482           ExecutorIslandCoarseningPass> {
483   void runOnOperation() override;
484 };
485 
runOnOperation()486 void ExecutorIslandCoarseningPass::runOnOperation() {
487   // Temporary datastructure to keep operands and results for each island.
488   // We define it here to grow and reuse the storage for the duration of the
489   // pass.
490   IslandOperandsAndResults island_operands_and_results;
491 
492   getOperation().walk([&](GraphOp graph) {
493     InsertDummyIslandForFetch(graph.GetFetch());
494 
495     // Compute an analysis that decides which islands should be merged together,
496     // and merge any island groups it finds.
497     CoarseningAnalysis analysis(graph);
498     for (const MergedIsland& island : analysis.GetMergableIslands())
499       MergeIslands(island, island_operands_and_results);
500   });
501 }
502 
503 }  // namespace
504 
505 std::unique_ptr<OperationPass<func::FuncOp>>
CreateTFExecutorIslandCoarseningPass()506 CreateTFExecutorIslandCoarseningPass() {
507   return std::make_unique<ExecutorIslandCoarseningPass>();
508 }
509 
510 }  // namespace tf_executor
511 }  // namespace mlir
512