/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass takes TensorFlow executor dialect IslandOps and
// merges the ones that contain operations marked to run on TPU.
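// Islands wrapping a single op that carries the `_tpu_replicate` attribute
// (or the `_tpu_compilation_status` attribute) are fused, per cluster, into
// one larger island, together with the islands wrapping the
// TPUReplicatedInput/TPUReplicatedOutput ops feeding into or consuming from
// them.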

#include <algorithm>
#include <iterator>
#include <queue>
#include <tuple>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/UseDefLists.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/core/platform/logging.h"

#define DEBUG_TYPE "tf-executor-tpu-v1-island-coarsening"

namespace mlir {
namespace tf_executor {

namespace {

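// Attributes used to identify a TPU cluster: ops carrying `_tpu_replicate`
// belong to a cluster, and `_tpu_compilation_status` marks the op reporting
// the cluster's compilation status; both attribute values name the cluster.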
constexpr llvm::StringRef kTpuReplicateAttr = "_tpu_replicate";
constexpr llvm::StringRef kTpuStatusAttr = "_tpu_compilation_status";

// This pass is a variant of the island coarsening that is limited to
// TPU-annotated operations and intended to preserve backward compatibility
// with TFv1.
struct TpuV1BridgeExecutorIslandCoarsening
    : public PassWrapper<TpuV1BridgeExecutorIslandCoarsening,
                         OperationPass<ModuleOp>> {
  StringRef getArgument() const final {
    return "tf-executor-tpu-v1-island-coarsening";
  }

  StringRef getDescription() const final {
    return "Merges TPU clusters IslandOps, intended for V1 compatibility mode";
  }

  void runOnOperation() override;
};

// Sorts the operations in the provided range to enforce dominance.
// This is useful after fusing / reorganizing Operations in a block and later
// needing to readjust the ordering to ensure dominance.
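// The sort works in rounds: each round walks the not-yet-scheduled ops and
// moves every op whose operands are all produced outside the range or by
// already-scheduled ops. If a round makes no progress, the remaining ops form
// a cycle and failure is returned.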
LogicalResult SortTopologically(Block::iterator begin, Block::iterator end) {
  Block* block = begin->getBlock();
  // Either sort from `begin` to end of block or both `begin` and
  // `end` should belong to the same block.
  assert((end == block->end() || end->getBlock() == block) &&
         "ops must be in the same block");

  // Track the ops that still need to be scheduled in a set.
  SmallPtrSet<Operation*, 16> unscheduled_ops;
  for (Operation& op : llvm::make_range(begin, end))
    unscheduled_ops.insert(&op);

  Block::iterator last_scheduled_op = begin;
  while (!unscheduled_ops.empty()) {
    bool scheduled_at_least_once = false;
    // Loop over the ops that are not sorted yet, find the ones that are
    // "ready", i.e. the ones for which no operand is produced by an op still
    // in the set, and "schedule" them (move them before `last_scheduled_op`).
    for (Operation& op : llvm::make_range(last_scheduled_op, end)) {
      WalkResult ready_to_schedule = op.walk([&](Operation* nested_op) {
        for (Value operand : nested_op->getOperands()) {
          Operation* defining_op = operand.getDefiningOp();
          if (!defining_op) continue;
          Operation* producer_in_block =
              block->findAncestorOpInBlock(*defining_op);
          if (producer_in_block && producer_in_block != &op &&
              unscheduled_ops.count(producer_in_block)) {
            // Found an operand that isn't scheduled yet, interrupt the walk.
            return WalkResult::interrupt();
          }
        }
        return WalkResult::advance();
      });
      if (ready_to_schedule.wasInterrupted()) continue;
      unscheduled_ops.erase(&op);
      if (Block::iterator(op) != last_scheduled_op)
        op.moveBefore(block, last_scheduled_op);
      else
        ++last_scheduled_op;
      scheduled_at_least_once = true;
    }
    if (!scheduled_at_least_once) return failure();
  }
  return success();
}

// Looks for an IslandOp that wraps a single operation tagged with the
// `_tpu_replicate` attribute, and merges it with all the following islands in
// the block that belong to the same TPU cluster, as well as the islands
// wrapping their TPUReplicatedInput/TPUReplicatedOutput ops. Sets the
// `changed` boolean to true if any island is merged.
// Returns a failure if a cycle prevents the merge from happening correctly
// without breaking dominance. The IR is left in an invalid state in case of
// failure.
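// The merged islands are replaced by a single island whose body contains the
// wrapped operations; values that escape the merged set become results of the
// new island, and the affected range is re-sorted topologically afterwards to
// restore dominance.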
LogicalResult MergeIsland(llvm::function_ref<bool(StringAttr, Operation*)>
                              is_op_calling_func_for_cluster,
                          Operation* op, bool* changed) {
  // Find the first island wrapping a single operation with the
  // `_tpu_replicate` attribute; it'll be used as the root of the algorithm to
  // find the other operations that are part of the same cluster.
  IslandOp island = dyn_cast<IslandOp>(*op);
  if (!island || !island.WrapsSingleOp()) return success();
  Operation& wrapped_op = island.GetBody().front();

  // TODO(b/188046643): Conservatively fail until pass is extended to fuse
  // chains of these ops.
  if (isa<TF::TPUPartitionedInputOp, TF::TPUPartitionedOutputOp>(wrapped_op)) {
    return failure();
  }

  StringAttr cluster_name =
      wrapped_op.getAttrOfType<StringAttr>(kTpuReplicateAttr);
  if (!cluster_name)
    cluster_name = wrapped_op.getAttrOfType<StringAttr>(kTpuStatusAttr);
  if (!cluster_name) return success();

  // We found a _tpu_replicate, let's build an island for the full cluster!
  LLVM_DEBUG(llvm::dbgs() << "Processing candidate island: "
                          << *island.getOperation() << "\n");

  // Collect the islands to merge together in this new cluster starting with
  // the given island.
  SmallVector<IslandOp, 16> islands;
  SmallPtrSet<Operation*, 16> wrapped_ops;
  for (Operation& candidate_op : llvm::make_early_inc_range(
           llvm::make_range(op->getIterator(), op->getBlock()->end()))) {
    IslandOp candidate_island = dyn_cast<IslandOp>(candidate_op);
    if (!candidate_island || !candidate_island.WrapsSingleOp()) continue;
    // Check if we have an operation with the expected attribute.
    Operation& candidate_wrapped_op = candidate_island.GetBody().front();

    // TODO(b/188046643): Conservatively fail until pass is extended to fuse
    // chains of these ops.
    if (isa<TF::TPUPartitionedInputOp, TF::TPUPartitionedOutputOp>(
            candidate_wrapped_op)) {
      return failure();
    }

    StringAttr candidate_cluster_name =
        candidate_wrapped_op.getAttrOfType<StringAttr>(kTpuReplicateAttr);
    if (!candidate_cluster_name)
      candidate_cluster_name =
          candidate_wrapped_op.getAttrOfType<StringAttr>(kTpuStatusAttr);
    if (candidate_cluster_name != cluster_name &&
        !is_op_calling_func_for_cluster(cluster_name, &candidate_wrapped_op))
      continue;

    // Look at captured operands to bring TPUReplicatedInputOp islands into the
    // merged island as well. Consider pulling in tf.Const too; some
    // optimizations can benefit from this.
    for (Value operand : candidate_wrapped_op.getOperands()) {
      IslandOp wrapper = dyn_cast_or_null<IslandOp>(operand.getDefiningOp());
      if (!wrapper || !wrapper.WrapsSingleOp()) continue;
      Operation& wrapped_op = wrapper.GetBody().front();
      if (!isa<TF::TPUReplicatedInputOp>(wrapped_op)) continue;
      if (wrapped_ops.count(&wrapped_op)) continue;
      wrapped_ops.insert(&wrapped_op);
      islands.push_back(wrapper);
    }
    islands.push_back(candidate_island);
    wrapped_ops.insert(&candidate_wrapped_op);

    // Look at results to bring TPUReplicatedOutputOp islands into the merged
    // island as well.
    for (Value result : candidate_island.getResults()) {
      for (Operation* user : result.getUsers()) {
        if (!isa<TF::TPUReplicatedOutputOp>(user)) continue;
        assert(!wrapped_ops.count(user) &&
               "unexpected already processed TPUReplicatedOutputOp");
        wrapped_ops.insert(user);
        islands.push_back(cast<IslandOp>(user->getParentOp()));
      }
    }
  }

  // If no other island was found to merge with the existing one, just
  // move on.
  if (islands.size() <= 1) return success();

  *changed = true;
  auto first_op_after =
      std::next(Block::iterator(islands.back().getOperation()));
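  // `first_op_after` marks the end of the range that will be re-sorted for
  // dominance once the merge below is complete.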

  // Compute the results of the merged island: these are the values produced by
  // the merged islands that have a use outside of the merged set, i.e. values
  // that escape.
  llvm::SmallVector<Type, 4> result_types;
  for (IslandOp new_op : islands) {
    for (Value result : new_op.outputs()) {
      if (llvm::any_of(result.getUsers(), [&](Operation* user) {
            return !wrapped_ops.count(user);
          }))
        result_types.push_back(result.getType());
    }
  }

  IslandOp new_island = OpBuilder(island).create<IslandOp>(
      island.getLoc(), result_types,
      /*control=*/ControlType::get(island.getContext()),
      /*controlInputs=*/island.getOperands());
  new_island.body().push_back(new Block);

  // Move the operations into the new island and gather the operands of the new
  // yield.
  Block& island_body = new_island.GetBody();
  SmallVector<Value, 16> yield_operands;
  for (IslandOp island : islands) {
    Operation& wrapped_op = island.GetBody().front();
    wrapped_op.moveBefore(&island_body, island_body.end());

    // Every result of the wrapped_op that escapes the merged island needs to
    // be passed to the yield operation.
    for (auto result : llvm::zip(island.outputs(), wrapped_op.getResults())) {
      if (llvm::any_of(std::get<0>(result).getUsers(), [&](Operation* user) {
            return !wrapped_ops.count(user);
          }))
        yield_operands.push_back(std::get<1>(result));
    }
  }
  OpBuilder::atBlockEnd(&island_body)
      .create<YieldOp>(new_island.getLoc(), yield_operands);

  // Remap the results of the merged islands to their users outside of the new
  // island.
  int current_result = 0;
  Value control = new_island.control();
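  // Results of the new island were added in the same order as the yield
  // operands above, so `current_result` stays in sync while walking the old
  // islands' outputs again below.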
  for (IslandOp island : islands) {
    YieldOp yield_op = island.GetYield();
    for (auto idx_result : llvm::enumerate(island.outputs())) {
      Value result = idx_result.value();

      bool has_external_use = false;
      for (OpOperand& use : llvm::make_early_inc_range(result.getUses())) {
        if (wrapped_ops.count(use.getOwner()))
          use.set(yield_op.getOperand(idx_result.index()));
        else
          has_external_use = true;
      }
      if (has_external_use) {
        result.replaceAllUsesWith(new_island.getResult(current_result));
        ++current_result;
      }
    }
    island.control().replaceAllUsesWith(control);
    island.erase();
  }

  // Ensure dominance by sorting the range of islands that were merged.
  return SortTopologically(Block::iterator(new_island.getOperation()),
                           first_op_after);
}

// Returns all functions that can be reached from TPUPartitionedCall ops.
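// Reachability is computed with a breadth-first traversal over a
// caller-to-callee map built from the module's symbol table; the resulting
// set is used below to skip those functions during coarsening.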
SmallPtrSet<Operation*, 16> FindTPUPartitionedCallReachableFunctions(
    ModuleOp module) {
  SymbolTableCollection table;
  SymbolUserMap symbol_map(table, module);
  llvm::DenseMap<FuncOp, llvm::DenseSet<FuncOp>> caller_callee_map;
  // Creates a work queue for determining reachability below.
  std::queue<FuncOp> function_worklist;

  for (auto func : module.getOps<FuncOp>()) {
    for (auto user : symbol_map.getUsers(func)) {
      // Populates the work queue with func ops called from TPUPartitionedCall.
      if (llvm::isa<TF::TPUPartitionedCallOp>(user)) {
        function_worklist.push(func);
      }
      // Populates the caller-to-callee map.
      if (FuncOp caller = user->getParentOfType<FuncOp>()) {
        caller_callee_map[caller].insert(func);
      }
    }
  }

  // Determines reachable functions starting from TPUPartitionedCall ops
  // and iteratively descending through called functions.
  SmallPtrSet<Operation*, 16> reachable_functions;
  while (!function_worklist.empty()) {
    FuncOp caller = function_worklist.front();
    function_worklist.pop();
    if (reachable_functions.insert(caller).second) {
      for (auto callee : caller_callee_map[caller]) {
        function_worklist.push(callee);
      }
    }
  }
  return reachable_functions;
}

void TpuV1BridgeExecutorIslandCoarsening::runOnOperation() {
  SymbolTable symbol_table(getOperation());

  // Maps TPU cluster names to the functions that contain operations for the
  // cluster.
  DenseMap<StringRef, DenseSet<FuncOp>> tpu_funcs;
  for (FuncOp func_op : getOperation().getOps<FuncOp>()) {
    func_op.walk([&](Operation* op) {
      StringAttr cluster_name =
          op->getAttrOfType<StringAttr>(kTpuReplicateAttr);
      if (!cluster_name)
        cluster_name = op->getAttrOfType<StringAttr>(kTpuStatusAttr);
      if (!cluster_name) return;
      tpu_funcs[cluster_name.getValue()].insert(func_op);
    });
  }

  // Returns true if the operation references a function that contains
  // operations for this cluster.
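  // This catches ops whose computation is outlined into a function (referenced
  // through a FlatSymbolRefAttr attribute, e.g. the branch functions of a
  // functional control-flow op): such an op is considered part of the cluster
  // whose ops live in the referenced function.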
  auto is_op_calling_func_for_cluster = [&](StringAttr cluster, Operation* op) {
    auto funcs_for_cluster = tpu_funcs.find(cluster.getValue());
    assert(funcs_for_cluster != tpu_funcs.end());
    assert(!funcs_for_cluster->second.empty());
    if (funcs_for_cluster->second.size() == 1) return false;
    for (NamedAttribute attr : op->getAttrs()) {
      auto symbol_ref = attr.second.dyn_cast<FlatSymbolRefAttr>();
      if (!symbol_ref) continue;
      FuncOp callee = symbol_table.lookup<FuncOp>(symbol_ref.getValue());
      if (!callee) continue;
      if (funcs_for_cluster->second.count(callee)) return true;
    }
    return false;
  };

  // Populates the skip set with functions reachable from TPUPartitionedCall
  // ops.
  const auto functions_to_skip =
      FindTPUPartitionedCallReachableFunctions(getOperation());
  for (FuncOp func_op : getOperation().getOps<FuncOp>()) {
    if (functions_to_skip.contains(func_op)) {
      continue;
    }

    func_op.walk([&](GraphOp graph) {
      Block& graph_body = graph.GetBody();

      // Iterate until fixed point on the block, as it may contain multiple
      // clusters.
      bool changed = true;
      while (changed) {
        changed = false;
        for (Operation& op : graph_body) {
          if (failed(
                  MergeIsland(is_op_calling_func_for_cluster, &op, &changed))) {
            graph.emitError()
                << "Merging island failed: the TPU cluster likely "
                << "contains a cycle with non-TPU operations or has "
                   "unsupported ops\n";
            signalPassFailure();
            return WalkResult::interrupt();
          }
          // If islands were merged, restart scanning the block from the
          // beginning as we lost track of where to continue.
          if (changed) break;
        }
      }
      return WalkResult::advance();
    });
  }
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>>
CreateTFExecutorTPUV1IslandCoarseningPass() {
  return std::make_unique<TpuV1BridgeExecutorIslandCoarsening>();
}

static PassRegistration<TpuV1BridgeExecutorIslandCoarsening> tpu_pass;

}  // namespace tf_executor
}  // namespace mlir
