/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <tuple>
#include <type_traits>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"

namespace mlir {
namespace TFTPU {

// This pass extracts a CPU computation cluster with `_xla_outside_compilation`
// annotation from the head or tail of a TPU cluster.
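//
// As an illustrative sketch (op names, types, and the device string here are
// hypothetical), a cluster such as
//
//   %0 = "tf_device.cluster"() ( {
//     %a = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : ...
//     %b = "tf.B"(%a) : ...
//     %c = "tf.C"(%b) {_xla_outside_compilation = "cluster1"} : ...
//     tf_device.return %c : ...
//   }) {...} : () -> ...
//
// may be rewritten so that `tf.A` runs in a `tf_device.launch` on the host
// device before the cluster (head extraction) and `tf.C` runs in a
// `tf_device.launch` on the host device after the cluster (tail extraction),
// leaving only `tf.B` inside the rewritten cluster.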

namespace {

constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";

bool HasOutsideCompilationAttribute(Operation* op) {
  return op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr) != nullptr;
}

// Finds op that created a given value. If the value is a BlockArgument, this
// returns the owner of the Block.
Operation* GetOpOfValue(Value value) {
  if (auto block_arg = value.dyn_cast<BlockArgument>())
    return block_arg.getOwner()->getParentOp();

  return value.getDefiningOp();
}

// Checks if `op` is nested in `block`.
bool OpInBlock(Operation* op, Block* block) {
  Block* op_block = op->getBlock();
  while (op_block) {
    if (op_block == block) return true;
    if (auto* parent_op = op_block->getParentOp()) {
      op_block = parent_op->getBlock();
    } else {
      break;
    }
  }
  return false;
}

// Wraps block in a Launch. Values defined in the block that have uses outside
// of it are returned from the Launch, and those external uses are remapped to
// the corresponding Launch results. If `before` is set to true, the Launch is
// created before `op`. Otherwise the Launch is created after `op`.
tf_device::LaunchOp CreateLaunchForBlock(OpBuilder* builder, Operation* op,
                                         bool before, Block* launch_block,
                                         llvm::StringRef host_device) {
  // Find results and result types of ops in block that need to be returned.
  llvm::SmallVector<Value, 4> launch_results;
  llvm::SmallVector<Type, 4> launch_result_types;
  for (Operation& head_outside_compiled_op : *launch_block) {
    for (Value result : head_outside_compiled_op.getResults()) {
      bool has_external_uses = false;
      for (Operation* user : result.getUsers()) {
        if (OpInBlock(user, launch_block)) continue;
        has_external_uses = true;
        break;
      }
      if (has_external_uses) {
        launch_results.push_back(result);
        launch_result_types.push_back(result.getType());
      }
    }
  }

  before ? builder->setInsertionPoint(op) : builder->setInsertionPointAfter(op);
  auto launch = builder->create<tf_device::LaunchOp>(
      op->getLoc(), builder->getStringAttr(host_device), launch_result_types);
  launch.body().push_back(launch_block);

  builder->setInsertionPointToEnd(&launch.GetBody());
  builder->create<tf_device::ReturnOp>(op->getLoc(), launch_results);

  return launch;
}

// Checks if an operation is a supported TPU embedding op.
bool IsEmbeddingOp(Operation* op) {
  return isa<TF::EnqueueTPUEmbeddingRaggedTensorBatchOp,
             TF::EnqueueTPUEmbeddingSparseTensorBatchOp,
             TF::RecvTPUEmbeddingActivationsOp,
             TF::SendTPUEmbeddingGradientsOp>(op);
}

// Returns a set of ops that are outside compiled and can be extracted to
// before the TPU computation. These ops are either connected to the inputs of
// the TPU computation or other ops that can be extracted, and have no operands
// from other ops in the TPU computation that cannot be extracted.
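// For example (an illustrative sketch with hypothetical op names): given
// `%a = "tf.A"(%arg0)` followed by `%b = "tf.B"(%a)` at the start of the
// cluster, both annotated for outside compilation, `tf.A` qualifies because
// its only operand comes from outside the cluster, and `tf.B` then qualifies
// because its only operand is produced by the already extracted `tf.A`. An
// annotated op whose operand is produced by a non-extractable op in the
// cluster is not returned.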
llvm::SmallVector<Operation*, 4> FindOutsideCompiledOpsAtHead(
    const TF::SideEffectAnalysis& side_effect_analysis,
    tf_device::ClusterOp cluster) {
  const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
      cluster->getParentOfType<FuncOp>());
  Region* cluster_region = &cluster.body();
  llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops;

  auto cluster_ops = cluster.GetBody().without_terminator();
  for (Operation& cluster_op : cluster_ops) {
    if (!HasOutsideCompilationAttribute(&cluster_op)) continue;
    // An outside compiled op can be extracted if its operands are not from
    // other ops in the cluster that cannot be extracted.

    // If this op is side effecting, check whether the closest preceding
    // side effecting op (ignoring embedding ops) has been head extracted.
    // Because ops are ordered by their side effects, if it has not, this op
    // cannot be head extracted either.
    // TODO(lyandy): Remove special handling of embedding ops. Currently the IR
    // is in a topological sort order and depending on that ordering, embedding
    // ops may prevent other ops from being head extracted.
    auto predecessors = analysis.DirectControlPredecessors(&cluster_op);
    if (!predecessors.empty() && !IsEmbeddingOp(&cluster_op)) {
      bool skip = false;
      for (Operation* predecessor : llvm::reverse(predecessors)) {
        if (IsEmbeddingOp(predecessor)) continue;
        skip = !head_outside_compiled_ops.contains(predecessor);
        break;
      }
      if (skip) continue;
    }

    auto walk_result = cluster_op.walk([&](Operation* op) {
      for (Value operand : op->getOperands()) {
        Operation* operand_op = GetOpOfValue(operand);
        if (head_outside_compiled_ops.count(operand_op)) continue;

        if (operand_op->getParentRegion() == cluster_region)
          return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });

    if (!walk_result.wasInterrupted())
      head_outside_compiled_ops.insert(&cluster_op);
  }

  return head_outside_compiled_ops.takeVector();
}

// Moves head outside compiled ops into their own `tf_device.LaunchOp`
// computation placed before the cluster.
void CreateHeadComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
                           llvm::ArrayRef<Operation*> head_outside_compiled_ops,
                           llvm::StringRef host_device) {
  Block* launch_block = new Block;
  for (Operation* head_outside_compiled_op : head_outside_compiled_ops) {
    head_outside_compiled_op->removeAttr(kXlaOutsideCompilationAttr);
    head_outside_compiled_op->moveBefore(launch_block, launch_block->end());
  }

  tf_device::LaunchOp launch = CreateLaunchForBlock(
      builder, cluster, /*before=*/true, launch_block, host_device);

  for (auto result : llvm::zip(launch.GetBody().getTerminator()->getOperands(),
                               launch.getResults()))
    replaceAllUsesInRegionWith(std::get<0>(result), std::get<1>(result),
                               cluster.body());
}

// Extracts and moves outside compiled ops that have no dependencies in the
// cluster to before the cluster.
mlir::LogicalResult LiftHeadOutsideCompiledOps(
    OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis,
    const mlir::TF::RuntimeDevices& devices, tf_device::ClusterOp cluster,
    std::string* host_device, bool* cluster_updated) {
  llvm::SmallVector<Operation*, 4> head_outside_compiled_ops =
      FindOutsideCompiledOpsAtHead(side_effect_analysis, cluster);
  if (head_outside_compiled_ops.empty()) return success();
  if (failed(tensorflow::GetHostDeviceOutsideComputation(devices, cluster,
                                                         host_device)))
    return failure();

  CreateHeadComputation(builder, cluster, head_outside_compiled_ops,
                        *host_device);

  *cluster_updated = true;
  return success();
}

// Fills `tail_outside_compiled_ops` with ops that are outside compiled and
// can be extracted to after the TPU computation, and `cluster_results` with
// new results of the cluster. These ops are either connected to the output of
// the TPU computation or other ops that can be extracted, and have no results
// used by other ops in the TPU computation that cannot be extracted.
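// For example (an illustrative sketch with hypothetical op names): if
// `%c = "tf.C"(%b)` is the last op in the cluster, is annotated for outside
// compilation, and is only used by the terminator, it can be tail extracted;
// its operand `%b`, produced inside the cluster, is then added to
// `cluster_results` so the rewritten cluster returns it for the extracted op
// to consume after the cluster.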
void FindOutsideCompiledOpsAtTailAndClusterResults(
    const TF::SideEffectAnalysis& side_effect_analysis,
    tf_device::ClusterOp cluster,
    llvm::SmallVectorImpl<Operation*>* tail_outside_compiled_ops,
    llvm::SmallVectorImpl<Value>* cluster_results) {
  const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
      cluster->getParentOfType<FuncOp>());
  Region* cluster_region = &cluster.body();
  llvm::SmallSetVector<Operation*, 4> tail_outside_compiled_ops_set;
  Operation* terminator = cluster.GetBody().getTerminator();
  llvm::SmallSetVector<Value, 4> cluster_results_set;
  cluster_results_set.insert(terminator->getOperands().begin(),
                             terminator->getOperands().end());

  auto cluster_ops = llvm::reverse(cluster.GetBody().without_terminator());
  for (Operation& cluster_op : cluster_ops) {
    if (!HasOutsideCompilationAttribute(&cluster_op)) continue;

    // If this op is side effecting, check whether the closest following
    // side effecting op (ignoring embedding ops) has been tail extracted.
    // Because ops are ordered by their side effects, if it has not, this op
    // cannot be tail extracted either.
    // TODO(lyandy): Remove special handling of embedding ops. Currently the IR
    // is in a topological sort order and depending on that ordering, embedding
    // ops may prevent other ops from being tail extracted.
    auto successors = analysis.DirectControlSuccessors(
        &cluster_op, [&terminator](Operation* op) { return op != terminator; });
    if (!successors.empty() && !IsEmbeddingOp(&cluster_op)) {
      bool skip = false;
      for (Operation* successor : successors) {
        if (IsEmbeddingOp(successor)) continue;
        skip = !tail_outside_compiled_ops_set.contains(successor);
        break;
      }
      if (skip) continue;
    }

    llvm::SmallVector<int, 4> results_to_forward;
    bool can_be_extracted =
        llvm::all_of(cluster_op.getUsers(), [&](Operation* op) {
          return op == terminator || tail_outside_compiled_ops_set.count(op);
        });
    if (!can_be_extracted) continue;

    // Collect operands of cluster op that are generated within the cluster.
    // These values should be returned by the cluster.
    cluster_op.walk([&](Operation* op) {
      for (Value operand : op->getOperands()) {
        Operation* operand_op = GetOpOfValue(operand);
        if (operand_op->getParentRegion() == cluster_region)
          cluster_results_set.insert(operand);
      }
    });

    // Remove results of op to be extracted as there are no uses in the cluster.
    for (Value result : cluster_op.getResults())
      cluster_results_set.remove(result);
    tail_outside_compiled_ops_set.insert(&cluster_op);
  }

  *tail_outside_compiled_ops = tail_outside_compiled_ops_set.takeVector();
  *cluster_results = cluster_results_set.takeVector();
}

// Moves tail outside compiled ops into their own `tf_device.LaunchOp`
// computation placed after the cluster.
void CreateTailComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
                           llvm::ArrayRef<Operation*> tail_outside_compiled_ops,
                           llvm::StringRef host_device) {
  Block* launch_block = new Block;
  for (Operation* tail_outside_compiled_op : tail_outside_compiled_ops) {
    tail_outside_compiled_op->removeAttr(kXlaOutsideCompilationAttr);
    tail_outside_compiled_op->moveBefore(launch_block, launch_block->begin());
  }

  tf_device::LaunchOp launch = CreateLaunchForBlock(
      builder, cluster, /*before=*/false, launch_block, host_device);

  auto operand_not_in_launch = [&](OpOperand& operand) {
    return !launch.getOperation()->isProperAncestor(operand.getOwner());
  };
  for (auto result : llvm::zip(launch.GetBody().getTerminator()->getOperands(),
                               launch.getResults()))
    std::get<0>(result).replaceUsesWithIf(std::get<1>(result),
                                          operand_not_in_launch);
}

// Updates cluster with updated cluster results after extracting tail outside
// compiled ops.
tf_device::ClusterOp UpdateClusterResults(
    OpBuilder* builder, tf_device::ClusterOp cluster,
    llvm::ArrayRef<Value> new_cluster_results) {
  Operation* old_terminator = cluster.GetBody().getTerminator();
  builder->setInsertionPoint(old_terminator);
  builder->create<tf_device::ReturnOp>(old_terminator->getLoc(),
                                       new_cluster_results);
  old_terminator->erase();

  builder->setInsertionPoint(cluster);
  llvm::SmallVector<Type, 4> new_cluster_result_types;
  new_cluster_result_types.reserve(new_cluster_results.size());
  for (const auto& new_cluster_result : new_cluster_results)
    new_cluster_result_types.push_back(new_cluster_result.getType());

  auto new_cluster = builder->create<tf_device::ClusterOp>(
      cluster.getLoc(), new_cluster_result_types,
      /*operands=*/llvm::ArrayRef<Value>{}, cluster.getAttrs());
  new_cluster.body().takeBody(cluster.body());

  auto operand_not_in_cluster = [&](OpOperand& operand) {
    return !new_cluster.getOperation()->isProperAncestor(operand.getOwner());
  };
  for (auto result :
       llvm::zip(new_cluster.GetBody().getTerminator()->getOperands(),
                 new_cluster.getResults()))
    std::get<0>(result).replaceUsesWithIf(std::get<1>(result),
                                          operand_not_in_cluster);

  cluster.erase();
  return new_cluster;
}

// Extracts and moves outside compiled ops that do not create dependencies in
// the cluster to after the cluster.
mlir::LogicalResult LiftTailOutsideCompiledOps(
    OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis,
    const mlir::TF::RuntimeDevices& devices, std::string host_device,
    tf_device::ClusterOp* cluster, bool* cluster_updated) {
  llvm::SmallVector<Operation*, 4> tail_outside_compiled_ops;
  llvm::SmallVector<Value, 4> cluster_results;
  FindOutsideCompiledOpsAtTailAndClusterResults(side_effect_analysis, *cluster,
                                                &tail_outside_compiled_ops,
                                                &cluster_results);
  if (tail_outside_compiled_ops.empty()) return success();

  if (host_device.empty())
    if (failed(tensorflow::GetHostDeviceOutsideComputation(devices, *cluster,
                                                           &host_device)))
      return failure();

  // Forward all results of cluster first. These results will be remapped once
  // a new cluster is formed.
  cluster->replaceAllUsesWith(
      cluster->GetBody().getTerminator()->getOperands());

  CreateTailComputation(builder, *cluster, tail_outside_compiled_ops,
                        host_device);

  *cluster = UpdateClusterResults(builder, *cluster, cluster_results);

  *cluster_updated = true;
  return success();
}

// Removes cluster results that alias values defined outside of the cluster,
// rewriting uses of those results to use the aliased values directly.
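// For example (an illustrative sketch): if the cluster terminator returns a
// value that is defined outside the cluster (e.g. a function argument passed
// straight through), uses of the corresponding cluster result are rewritten
// to use that value directly and the result is dropped from the rebuilt
// cluster.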
void RemoveClusterAliasedOutputs(OpBuilder* builder,
                                 tf_device::ClusterOp cluster) {
  llvm::SmallVector<Value, 4> used_old_cluster_results;
  llvm::SmallVector<Value, 4> new_cluster_results;
  llvm::SmallVector<Type, 4> new_cluster_result_types;
  Operation* cluster_terminator = cluster.GetBody().getTerminator();
  for (auto result :
       llvm::zip(cluster_terminator->getOperands(), cluster.getResults())) {
    Value cluster_terminator_operand = std::get<0>(result);
    if (cluster_terminator_operand.getDefiningOp() &&
        cluster.getOperation()->isProperAncestor(
            cluster_terminator_operand.getDefiningOp())) {
      new_cluster_results.push_back(cluster_terminator_operand);
      new_cluster_result_types.push_back(cluster_terminator_operand.getType());
      used_old_cluster_results.push_back(std::get<1>(result));
    } else {
      std::get<1>(result).replaceAllUsesWith(cluster_terminator_operand);
    }
  }

  if (new_cluster_results.size() == cluster.getNumResults()) return;

  builder->setInsertionPoint(cluster);
  auto new_cluster = builder->create<tf_device::ClusterOp>(
      cluster.getLoc(), new_cluster_result_types,
      /*operands=*/llvm::ArrayRef<Value>{}, cluster.getAttrs());
  new_cluster.body().takeBody(cluster.body());
  new_cluster.GetBody().getTerminator()->setOperands(new_cluster_results);

  for (auto result :
       llvm::zip(used_old_cluster_results, new_cluster.getResults()))
    std::get<0>(result).replaceAllUsesWith(std::get<1>(result));

  cluster.erase();
}

struct TPUExtractHeadTailOutsideCompilation
    : public PassWrapper<TPUExtractHeadTailOutsideCompilation,
                         OperationPass<ModuleOp>> {
  void runOnOperation() override;
};

void TPUExtractHeadTailOutsideCompilation::runOnOperation() {
  auto& side_effect_analysis = getAnalysis<TF::SideEffectAnalysis>();
  // Get runtime devices information from the closest parent module.
  auto module = getOperation();
  mlir::TF::RuntimeDevices devices;
  if (failed(tensorflow::GetDevicesFromOp(module, &devices)))
    return signalPassFailure();

  OpBuilder builder(&getContext());
  llvm::SmallVector<tf_device::ClusterOp, 4> clusters;
  module.walk(
      [&](tf_device::ClusterOp cluster) { clusters.push_back(cluster); });

  for (tf_device::ClusterOp cluster : clusters) {
    std::string host_device;
    bool cluster_updated = false;
    if (failed(LiftHeadOutsideCompiledOps(&builder, side_effect_analysis,
                                          devices, cluster, &host_device,
                                          &cluster_updated)) ||
        failed(LiftTailOutsideCompiledOps(&builder, side_effect_analysis,
                                          devices, host_device, &cluster,
                                          &cluster_updated)))
      return signalPassFailure();
    if (cluster_updated) RemoveClusterAliasedOutputs(&builder, cluster);
  }
}

}  // anonymous namespace

std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUExtractHeadTailOutsideCompilationPass() {
  return std::make_unique<TPUExtractHeadTailOutsideCompilation>();
}

static PassRegistration<TPUExtractHeadTailOutsideCompilation> pass(
    "tf-tpu-extract-head-tail-outside-compilation",
    "Extracts TPU head or tail outside compilation to separate "
    "parallel_execute.");

}  // namespace TFTPU
}  // namespace mlir