• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This transformation pass takes TensorFlow executor dialect IslandOps and
// merges the ones that contain operations marked to run on TPU.
18 
19 #include <algorithm>
20 #include <iterator>
21 #include <queue>
22 #include <tuple>
23 
24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/None.h"
26 #include "llvm/ADT/Optional.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/SetVector.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/ADT/StringRef.h"
31 #include "llvm/ADT/iterator_range.h"
32 #include "llvm/Support/Casting.h"
33 #include "llvm/Support/Debug.h"
34 #include "mlir/IR/Attributes.h"  // from @llvm-project
35 #include "mlir/IR/Block.h"  // from @llvm-project
36 #include "mlir/IR/Builders.h"  // from @llvm-project
37 #include "mlir/IR/Location.h"  // from @llvm-project
38 #include "mlir/IR/Operation.h"  // from @llvm-project
39 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
40 #include "mlir/IR/UseDefLists.h"  // from @llvm-project
41 #include "mlir/IR/Visitors.h"  // from @llvm-project
42 #include "mlir/Pass/Pass.h"  // from @llvm-project
43 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
44 #include "mlir/Support/LLVM.h"  // from @llvm-project
45 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
46 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
47 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
48 #include "tensorflow/core/platform/logging.h"
49 
50 #define DEBUG_TYPE "tf-executor-tpu-v1-island-coarsening"
51 
52 namespace mlir {
53 namespace tf_executor {
54 
55 namespace {
56 
57 constexpr llvm::StringRef kTpuReplicateAttr = "_tpu_replicate";
58 constexpr llvm::StringRef kTpuStatusAttr = "_tpu_compilation_status";
59 
60 // This pass is a variant of the island coarsening that is limited to
61 // TPU-annotated operations and intended to preserve backward compatibility with
62 // TFv1.
63 struct TpuV1BridgeExecutorIslandCoarsening
64     : public PassWrapper<TpuV1BridgeExecutorIslandCoarsening,
65                          OperationPass<ModuleOp>> {
getArgumentmlir::tf_executor::__anon70a6e7720111::TpuV1BridgeExecutorIslandCoarsening66   StringRef getArgument() const final {
67     return "tf-executor-tpu-v1-island-coarsening";
68   }
69 
getDescriptionmlir::tf_executor::__anon70a6e7720111::TpuV1BridgeExecutorIslandCoarsening70   StringRef getDescription() const final {
71     return "Merges TPU clusters IslandOps, intended for V1 compatibility mode";
72   }
73 
74   void runOnOperation() override;
75 };
76 
77 // Sorts the operations in the provided range to enforce dominance.
78 // This is useful after fusing / reorganizing Operations in a block and later
79 // needing to readjust the ordering to ensure dominance.
SortTopologically(Block::iterator begin,Block::iterator end)80 LogicalResult SortTopologically(Block::iterator begin, Block::iterator end) {
81   Block* block = begin->getBlock();
82   // Either sort from `begin` to end of block or both `begin` and
83   // `end` should belong to the same block.
84   assert(end == block->end() ||
85          end->getBlock() == block && "ops must be in the same block");
86 
87   // Track the ops that still need to be scheduled in a set.
88   SmallPtrSet<Operation*, 16> unscheduled_ops;
89   for (Operation& op : llvm::make_range(begin, end))
90     unscheduled_ops.insert(&op);
91 
92   Block::iterator last_scheduled_op = begin;
93   while (!unscheduled_ops.empty()) {
94     bool scheduled_at_least_once = false;
95     // Loop over the ops that are not sorted yet, try to find the ones "ready",
96     // i.e. the ones for which there aren't any operand produced by an op in the
97     // set, and "schedule" it (move it before the last_scheduled_op).
98     for (Operation& op : llvm::make_range(last_scheduled_op, end)) {
99       WalkResult ready_to_schedule = op.walk([&](Operation* nested_op) {
100         for (Value operand : nested_op->getOperands()) {
101           Operation* defining_op = operand.getDefiningOp();
102           if (!defining_op) continue;
103           Operation* producer_in_block =
104               block->findAncestorOpInBlock(*defining_op);
105           if (producer_in_block && producer_in_block != &op &&
106               unscheduled_ops.count(producer_in_block)) {
107             // Found an operand that isn't scheduled yet, interrupt the walk.
108             return WalkResult::interrupt();
109           }
110         }
111         return WalkResult::advance();
112       });
113       if (ready_to_schedule.wasInterrupted()) continue;
114       unscheduled_ops.erase(&op);
115       if (Block::iterator(op) != last_scheduled_op)
116         op.moveBefore(block, last_scheduled_op);
117       else
118         ++last_scheduled_op;
119       scheduled_at_least_once = true;
120     }
121     if (!scheduled_at_least_once) return failure();
122   }
123   return success();
124 }
125 
126 // Looks for an IslandOp that wraps a single operation tagged with the
127 // _tpu_replicate attribute, and merges it with all the following operations in
128 // the block. Sets the `changed` boolean to true if any island is merged.
129 // Returns a failure if a cycle prevents the merge from happening correctly
130 // without breaking dominance. The IR is left in invalid state in case of
131 // failure.
MergeIsland(llvm::function_ref<bool (StringAttr,Operation *)> is_op_calling_func_for_cluster,Operation * op,bool * changed)132 LogicalResult MergeIsland(llvm::function_ref<bool(StringAttr, Operation*)>
133                               is_op_calling_func_for_cluster,
134                           Operation* op, bool* changed) {
135   // Find the first island wrapping a single operation with the `_tpu_replicate`
136   // attribute, it'll be used as the root of the algorithm to find the other
137   // operations that are part of the same cluster.
138   IslandOp island = dyn_cast<IslandOp>(*op);
139   if (!island || !island.WrapsSingleOp()) return success();
140   Operation& wrapped_op = island.GetBody().front();
141 
142   // TODO(b/188046643): Conservatively fail until pass is extended to fuse
143   // chains of these ops.
144   if (isa<TF::TPUPartitionedInputOp, TF::TPUPartitionedOutputOp>(wrapped_op)) {
145     return failure();
146   }
147 
148   StringAttr cluster_name =
149       wrapped_op.getAttrOfType<StringAttr>(kTpuReplicateAttr);
150   if (!cluster_name)
151     cluster_name = wrapped_op.getAttrOfType<StringAttr>(kTpuStatusAttr);
152   if (!cluster_name) return success();
153 
154   // We found a _tpu_replicate, let's build an island for the full cluster!
155   LLVM_DEBUG(llvm::dbgs() << "Processing candidate island: "
156                           << *island.getOperation() << "\n");
157 
158   // Collect the islands to merge together in this new cluster starting with the
159   // given island.
160   SmallVector<IslandOp, 16> islands;
161   SmallPtrSet<Operation*, 16> wrapped_ops;
162   for (Operation& candidate_op : llvm::make_early_inc_range(
163            llvm::make_range(op->getIterator(), op->getBlock()->end()))) {
164     IslandOp candidate_island = dyn_cast<IslandOp>(candidate_op);
165     if (!candidate_island || !candidate_island.WrapsSingleOp()) continue;
166     // Check if we have an operation with the expected attribute.
167     Operation& candidate_wrapped_op = candidate_island.GetBody().front();
168 
169     // TODO(b/188046643): Conservatively fail until pass is extended to fuse
170     // chains of these ops.
171     if (isa<TF::TPUPartitionedInputOp, TF::TPUPartitionedOutputOp>(
172             candidate_wrapped_op)) {
173       return failure();
174     }
175 
176     StringAttr candidate_cluster_name =
177         candidate_wrapped_op.getAttrOfType<StringAttr>(kTpuReplicateAttr);
178     if (!candidate_cluster_name)
179       candidate_cluster_name =
180           candidate_wrapped_op.getAttrOfType<StringAttr>(kTpuStatusAttr);
181     if (candidate_cluster_name != cluster_name &&
182         !is_op_calling_func_for_cluster(cluster_name, &candidate_wrapped_op))
183       continue;
184 
185     // Look at captured operands to bring-in ReplicatedInputOp in the
186     // island as well. Consider pulling in tf.Const, some optimizations can
187     // benefit from this.
188     for (Value operand : candidate_wrapped_op.getOperands()) {
189       IslandOp wrapper = dyn_cast_or_null<IslandOp>(operand.getDefiningOp());
190       if (!wrapper || !wrapper.WrapsSingleOp()) continue;
191       Operation& wrapped_op = wrapper.GetBody().front();
192       if (!isa<TF::TPUReplicatedInputOp>(wrapped_op)) continue;
193       if (wrapped_ops.count(&wrapped_op)) continue;
194       wrapped_ops.insert(&wrapped_op);
195       islands.push_back(wrapper);
196     }
197     islands.push_back(candidate_island);
198     wrapped_ops.insert(&candidate_wrapped_op);
199 
200     // Look at results to bring-in ReplicatedOutputOp in the island as well.
201     for (Value result : candidate_island.getResults()) {
202       for (OpOperand use : result.getUsers()) {
203         Operation* user = use.getOwner();
204         if (!isa<TF::TPUReplicatedOutputOp>(user)) continue;
205         assert(!wrapped_ops.count(user) &&
206                "unexpected already processed TPUReplicatedOutputOp");
207         wrapped_ops.insert(user);
208         islands.push_back(cast<IslandOp>(user->getParentOp()));
209       }
210     }
211   }
212 
213   // If no other island was found to merge with the existing one, just
214   // move on.
215   if (islands.size() <= 1) return success();
216 
217   *changed = true;
218   auto first_op_after =
219       std::next(Block::iterator(islands.back().getOperation()));
220 
221   // Compute the result of the merged island, these are the values produced by
222   // the islands that are merged if they have a use in an island not merged,
223   // i.e. a value that escapes.
224   llvm::SmallVector<Type, 4> result_types;
225   for (IslandOp new_op : islands) {
226     for (Value result : new_op.outputs()) {
227       if (llvm::any_of(result.getUsers(), [&](OpOperand user) {
228             return !wrapped_ops.count(user.getOwner());
229           }))
230         result_types.push_back(result.getType());
231     }
232   }
233 
234   IslandOp new_island = OpBuilder(island).create<IslandOp>(
235       island.getLoc(), result_types,
236       /*control=*/ControlType::get(island.getContext()),
237       /*controlInputs=*/island.getOperands());
238   new_island.body().push_back(new Block);
239 
240   // Move the operations in the new island, gather the results of the new yield.
241   Block& island_body = new_island.GetBody();
242   SmallVector<Value, 16> yield_operands;
243   for (IslandOp island : islands) {
244     Operation& wrapped_op = island.GetBody().front();
245     wrapped_op.moveBefore(&island_body, island_body.end());
246 
247     // For every result of the wrapped_op, it needs to get passed to the yield
248     // operation, only if it escapes the island.
249     for (auto result : llvm::zip(island.outputs(), wrapped_op.getResults())) {
250       if (llvm::any_of(std::get<0>(result).getUsers(), [&](OpOperand user) {
251             return !wrapped_ops.count(user.getOwner());
252           }))
253         yield_operands.push_back(std::get<1>(result));
254     }
255   }
256   OpBuilder::atBlockEnd(&island_body)
257       .create<YieldOp>(new_island.getLoc(), yield_operands);
258 
259   // remap results of the new islands to the user outside of the island.
260   int current_result = 0;
261   Value control = new_island.control();
262   for (IslandOp island : islands) {
263     YieldOp yield_op = island.GetYield();
264     for (auto idx_result : llvm::enumerate(island.outputs())) {
265       Value result = idx_result.value();
266 
267       bool has_external_use = false;
268       for (OpOperand& use : llvm::make_early_inc_range(result.getUses())) {
269         if (wrapped_ops.count(use.getOwner()))
270           use.set(yield_op.getOperand(idx_result.index()));
271         else
272           has_external_use = true;
273       }
274       if (has_external_use) {
275         result.replaceAllUsesWith(new_island.getResult(current_result));
276         ++current_result;
277       }
278     }
279     island.control().replaceAllUsesWith(control);
280     island.erase();
281   }
282 
283   // Ensure dominance by sorting the range of islands that were merged.
284   return SortTopologically(Block::iterator(new_island.getOperation()),
285                            first_op_after);
286 }
287 
288 // Returns all functions that can be reached from TPUPartitionedCall ops.
FindTPUPartitionedCallReachableFunctions(ModuleOp module)289 SmallPtrSet<Operation*, 16> FindTPUPartitionedCallReachableFunctions(
290     ModuleOp module) {
291   SymbolTableCollection table;
292   SymbolUserMap symbol_map(table, module);
293   llvm::DenseMap<FuncOp, llvm::DenseSet<FuncOp>> caller_callee_map;
294   // Creates work queue for determining reachability below.
295   std::queue<FuncOp> function_worklist;
296 
297   for (auto func : module.getOps<FuncOp>()) {
298     for (auto user : symbol_map.getUsers(func)) {
299       // Populates work queue with func ops called from TPUPartionedCall.
300       if (llvm::isa<TF::TPUPartitionedCallOp>(user)) {
301         function_worklist.push(func);
302       }
303       // Populates caller to called func map.
304       if (FuncOp caller = user->getParentOfType<FuncOp>()) {
305         caller_callee_map[caller].insert(func);
306       }
307     }
308   }
309 
310   // Determines reached ops starting from TPUPartionedCall ops
311   // and iteratively descending through called ops.
312   SmallPtrSet<Operation*, 16> reachable_functions;
313   while (!function_worklist.empty()) {
314     FuncOp caller = function_worklist.front();
315     function_worklist.pop();
316     if (reachable_functions.insert(caller).second) {
317       for (auto callee : caller_callee_map[caller]) {
318         function_worklist.push(callee);
319       }
320     }
321   }
322   return reachable_functions;
323 }
324 
runOnOperation()325 void TpuV1BridgeExecutorIslandCoarsening::runOnOperation() {
326   SymbolTable symbol_table(getOperation());
327 
328   // Map tpu cluster names to the functions that contain operations for this
329   // cluster.
330   DenseMap<StringRef, DenseSet<FuncOp>> tpu_funcs;
331   for (FuncOp func_op : getOperation().getOps<FuncOp>()) {
332     func_op.walk([&](Operation* op) {
333       StringAttr cluster_name =
334           op->getAttrOfType<StringAttr>(kTpuReplicateAttr);
335       if (!cluster_name)
336         cluster_name = op->getAttrOfType<StringAttr>(kTpuStatusAttr);
337       if (!cluster_name) return;
338       tpu_funcs[cluster_name.getValue()].insert(func_op);
339     });
340   }
341 
342   // Return true if the operation is containing a reference to a function
343   // containing operations for this cluster.
344   auto is_op_calling_func_for_cluster = [&](StringAttr cluster, Operation* op) {
345     auto funcs_for_cluster = tpu_funcs.find(cluster.getValue());
346     assert(funcs_for_cluster != tpu_funcs.end());
347     assert(!funcs_for_cluster->second.empty());
348     if (funcs_for_cluster->second.size() == 1) return false;
349     for (NamedAttribute attr : op->getAttrs()) {
350       auto symbol_ref = attr.second.dyn_cast<FlatSymbolRefAttr>();
351       if (!symbol_ref) continue;
352       FuncOp callee = symbol_table.lookup<FuncOp>(symbol_ref.getValue());
353       if (!callee) continue;
354       if (funcs_for_cluster->second.count(callee)) return true;
355     }
356     return false;
357   };
358 
359   // Populates skip set with functions reachable from TPUPartionedCall ops.
360   const auto functions_to_skip =
361       FindTPUPartitionedCallReachableFunctions(getOperation());
362   for (FuncOp func_op : getOperation().getOps<FuncOp>()) {
363     if (functions_to_skip.contains(func_op)) {
364       continue;
365     }
366 
367     func_op.walk([&](GraphOp graph) {
368       Block& graph_body = graph.GetBody();
369 
370       // Iterate until fixed point on the block, as it may contain multiple
371       // clusters.
372       bool changed = true;
373       while (changed) {
374         changed = false;
375         for (Operation& op : graph_body) {
376           if (failed(
377                   MergeIsland(is_op_calling_func_for_cluster, &op, &changed))) {
378             graph.emitError()
379                 << "Merging island failed: the TPU cluster likely "
380                 << "contains a cycle with non-TPU operations or has "
381                    "unsupported ops\n";
382             signalPassFailure();
383             return WalkResult::interrupt();
384           }
385           // If islands were merged, restart scanning the block from the
386           // beginning as we lost track of where to continue.
387           if (changed) break;
388         }
389       }
390       return WalkResult::advance();
391     });
392   }
393 }
394 
395 }  // namespace
396 
397 std::unique_ptr<OperationPass<ModuleOp>>
CreateTFExecutorTPUV1IslandCoarseningPass()398 CreateTFExecutorTPUV1IslandCoarseningPass() {
399   return std::make_unique<TpuV1BridgeExecutorIslandCoarsening>();
400 }
401 
402 static PassRegistration<TpuV1BridgeExecutorIslandCoarsening> tpu_pass;
403 
404 }  // namespace tf_executor
405 }  // namespace mlir
406