/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <tuple>
#include <type_traits>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"

namespace mlir {
namespace TFTPU {

// This pass extracts a CPU computation cluster with `_xla_outside_compilation`
// annotation from the head or tail of a TPU cluster.
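//
// As a rough sketch of the intended transformation (illustrative only: the op
// names `tf.OpA`/`tf.OpB`, types, and device string below are placeholders,
// and the IR is simplified rather than exact textual syntax), a cluster such
// as
//
//   %0 = "tf_device.cluster"() ({
//     %a = "tf.OpA"(%arg0) {_xla_outside_compilation = "cluster1"} : ...
//     %b = "tf.OpB"(%a) : ...
//     tf_device.return %b : ...
//   }) : () -> tensor<i32>
//
// has `tf.OpA`, which only depends on values defined outside the cluster,
// lifted into a host-side `tf_device.launch` placed before the cluster, with
// its result fed back into the cluster:
//
//   %a = "tf_device.launch"() ({
//     %a_host = "tf.OpA"(%arg0) : ...
//     tf_device.return %a_host : ...
//   }) {device = "<host device>"} : () -> ...
//   %0 = "tf_device.cluster"() ({
//     %b = "tf.OpB"(%a) : ...
//     tf_device.return %b : ...
//   }) : () -> tensor<i32>
//
// Tail extraction is the mirror image: outside compiled ops whose results are
// only consumed by the cluster terminator (or by other extractable ops) are
// moved into a launch placed after the cluster.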

namespace {

constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";

bool HasOutsideCompilationAttribute(Operation* op) {
  return op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr) != nullptr;
}

// Finds op that created a given value. If the value is a BlockArgument, this
// returns the owner of the Block.
Operation* GetOpOfValue(Value value) {
  if (auto block_arg = value.dyn_cast<BlockArgument>())
    return block_arg.getOwner()->getParentOp();

  return value.getDefiningOp();
}

// Checks if `op` is nested in `block`.
bool OpInBlock(Operation* op, Block* block) {
  Block* op_block = op->getBlock();
  while (op_block) {
    if (op_block == block) return true;
    if (auto* parent_op = op_block->getParentOp()) {
      op_block = parent_op->getBlock();
    } else {
      break;
    }
  }
  return false;
}

// Wraps `launch_block` in a `tf_device.LaunchOp`. Values defined in the block
// that are used externally become results of the Launch, and those external
// uses are remapped to the Launch results. If `before` is true, the Launch is
// created before `op`; otherwise it is created after `op`.
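//
// For example (illustrative only; block contents and types are placeholders),
// wrapping a block that contains `%a = "tf.OpA"(...)`, where `%a` is also
// used outside the block, produces roughly:
//
//   "tf_device.launch"() ({
//     %a = "tf.OpA"(...) : ...
//     tf_device.return %a : ...
//   }) {device = host_device} : () -> ...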
tf_device::LaunchOp CreateLaunchForBlock(OpBuilder* builder, Operation* op,
                                         bool before, Block* launch_block,
                                         llvm::StringRef host_device) {
  // Find results and result types of ops in the block that need to be
  // returned.
  llvm::SmallVector<Value, 4> launch_results;
  llvm::SmallVector<Type, 4> launch_result_types;
  for (Operation& head_outside_compiled_op : *launch_block) {
    for (Value result : head_outside_compiled_op.getResults()) {
      bool has_external_uses = false;
      for (Operation* user : result.getUsers()) {
        if (OpInBlock(user, launch_block)) continue;
        has_external_uses = true;
        break;
      }
      if (has_external_uses) {
        launch_results.push_back(result);
        launch_result_types.push_back(result.getType());
      }
    }
  }

  before ? builder->setInsertionPoint(op) : builder->setInsertionPointAfter(op);
  auto launch = builder->create<tf_device::LaunchOp>(
      op->getLoc(), builder->getStringAttr(host_device), launch_result_types);
  launch.body().push_back(launch_block);

  builder->setInsertionPointToEnd(&launch.GetBody());
  builder->create<tf_device::ReturnOp>(op->getLoc(), launch_results);

  return launch;
}

// Checks if an operation is a supported TPU embedding op.
bool IsEmbeddingOp(Operation* op) {
  return isa<TF::EnqueueTPUEmbeddingRaggedTensorBatchOp,
             TF::EnqueueTPUEmbeddingSparseTensorBatchOp,
             TF::RecvTPUEmbeddingActivationsOp,
             TF::SendTPUEmbeddingGradientsOp>(op);
}

// Returns a set of ops that are outside compiled and can be extracted to
// before the TPU computation. These ops are either connected to the inputs of
// the TPU computation or other ops that can be extracted, and have no operands
// from other ops in the TPU computation that cannot be extracted.
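//
// As a small illustration (op names are placeholders), given cluster ops
//
//   %a = "tf.OpA"(%arg0) {_xla_outside_compilation = "c"}  // operand defined
//                                                          // outside cluster
//   %b = "tf.OpB"(%arg0)                                   // TPU op
//   %c = "tf.OpC"(%b) {_xla_outside_compilation = "c"}     // operand produced
//                                                          // by a TPU op
//
// `tf.OpA` can be head extracted, while `tf.OpC` cannot because its operand
// `%b` is produced by an op that must stay in the TPU computation.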
llvm::SmallVector<Operation*, 4> FindOutsideCompiledOpsAtHead(
    const TF::SideEffectAnalysis& side_effect_analysis,
    tf_device::ClusterOp cluster) {
  const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
      cluster->getParentOfType<FuncOp>());
  Region* cluster_region = &cluster.body();
  llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops;

  auto cluster_ops = cluster.GetBody().without_terminator();
  for (Operation& cluster_op : cluster_ops) {
    if (!HasOutsideCompilationAttribute(&cluster_op)) continue;
    // An outside compiled op can be extracted if its operands are not from
    // other ops in the cluster that cannot be extracted.

    // Check whether the closest side effecting predecessor of this op (if
    // any) can itself be head extracted. Because op ordering is constrained
    // by side effects, if it cannot, this op cannot be head extracted either.
    // TODO(lyandy): Remove special handling of embedding ops. Currently the IR
    // is in a topological sort order and depending on that ordering, embedding
    // ops may prevent other ops from being head extracted.
    auto predecessors = analysis.DirectControlPredecessors(&cluster_op);
    if (!predecessors.empty() && !IsEmbeddingOp(&cluster_op)) {
      bool skip = false;
      for (Operation* predecessor : llvm::reverse(predecessors)) {
        if (IsEmbeddingOp(predecessor)) continue;
        skip = !head_outside_compiled_ops.contains(predecessor);
        break;
      }
      if (skip) continue;
    }

    auto walk_result = cluster_op.walk([&](Operation* op) {
      for (Value operand : op->getOperands()) {
        Operation* operand_op = GetOpOfValue(operand);
        if (head_outside_compiled_ops.count(operand_op) ||
            operand_op == &cluster_op)
          continue;

        if (operand_op->getParentRegion() == cluster_region)
          return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });

    if (!walk_result.wasInterrupted())
      head_outside_compiled_ops.insert(&cluster_op);
  }

  return head_outside_compiled_ops.takeVector();
}

// Moves head outside compiled ops into their own `tf_device.LaunchOp`
// computation before the cluster.
void CreateHeadComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
                           llvm::ArrayRef<Operation*> head_outside_compiled_ops,
                           llvm::StringRef host_device) {
  Block* launch_block = new Block;
  for (Operation* head_outside_compiled_op : head_outside_compiled_ops) {
    head_outside_compiled_op->removeAttr(kXlaOutsideCompilationAttr);
    head_outside_compiled_op->moveBefore(launch_block, launch_block->end());
  }

  tf_device::LaunchOp launch = CreateLaunchForBlock(
      builder, cluster, /*before=*/true, launch_block, host_device);

  for (auto result : llvm::zip(launch.GetBody().getTerminator()->getOperands(),
                               launch.getResults()))
    replaceAllUsesInRegionWith(std::get<0>(result), std::get<1>(result),
                               cluster.body());
}

// Extracts and moves outside compiled ops that have no dependencies in the
// cluster to before the cluster.
mlir::LogicalResult LiftHeadOutsideCompiledOps(
    OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis,
    const mlir::TF::RuntimeDevices& devices, tf_device::ClusterOp cluster,
    std::string* host_device, bool* cluster_updated) {
  llvm::SmallVector<Operation*, 4> head_outside_compiled_ops =
      FindOutsideCompiledOpsAtHead(side_effect_analysis, cluster);
  if (head_outside_compiled_ops.empty()) return success();
  if (failed(tensorflow::GetHostDeviceOutsideComputation(devices, cluster,
                                                         host_device)))
    return failure();

  CreateHeadComputation(builder, cluster, head_outside_compiled_ops,
                        *host_device);

  *cluster_updated = true;
  return success();
}

// Fills `tail_outside_compiled_ops` with ops that are outside compiled and
// can be extracted to after the TPU computation, and `cluster_results` with
// new results of the cluster. These ops are either connected to the output of
// the TPU computation or other ops that can be extracted, and have no results
// used by other ops in the TPU computation that cannot be extracted.
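//
// As a small illustration (op names are placeholders), given cluster ops
//
//   %a = "tf.OpA"(%arg0)                                 // TPU op
//   %b = "tf.OpB"(%a) {_xla_outside_compilation = "c"}   // only used by the
//   tf_device.return %b                                  // cluster terminator
//
// `tf.OpB` can be tail extracted, and `%a` becomes a new cluster result so the
// extracted op can consume it after the cluster.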
void FindOutsideCompiledOpsAtTailAndClusterResults(
    const TF::SideEffectAnalysis& side_effect_analysis,
    tf_device::ClusterOp cluster,
    llvm::SmallVectorImpl<Operation*>* tail_outside_compiled_ops,
    llvm::SmallVectorImpl<Value>* cluster_results) {
  const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
      cluster->getParentOfType<FuncOp>());
  Region* cluster_region = &cluster.body();
  llvm::SmallSetVector<Operation*, 4> tail_outside_compiled_ops_set;
  Operation* terminator = cluster.GetBody().getTerminator();
  llvm::SmallSetVector<Value, 4> cluster_results_set;
  cluster_results_set.insert(terminator->getOperands().begin(),
                             terminator->getOperands().end());

  auto cluster_ops = llvm::reverse(cluster.GetBody().without_terminator());
  for (Operation& cluster_op : cluster_ops) {
    if (!HasOutsideCompilationAttribute(&cluster_op)) continue;

    // Check whether the closest side effecting successor of this op (if any)
    // can itself be tail extracted. Because op ordering is constrained by
    // side effects, if it cannot, this op cannot be tail extracted either.
    // TODO(lyandy): Remove special handling of embedding ops. Currently the IR
    // is in a topological sort order and depending on that ordering, embedding
    // ops may prevent other ops from being tail extracted.
    auto successors = analysis.DirectControlSuccessors(
        &cluster_op, [&terminator](Operation* op) { return op != terminator; });
    if (!successors.empty() && !IsEmbeddingOp(&cluster_op)) {
      bool skip = false;
      for (Operation* successor : successors) {
        if (IsEmbeddingOp(successor)) continue;
        skip = !tail_outside_compiled_ops_set.contains(successor);
        break;
      }
      if (skip) continue;
    }

    llvm::SmallVector<int, 4> results_to_forward;
    bool can_be_extracted =
        llvm::all_of(cluster_op.getUsers(), [&](Operation* op) {
          return op == terminator || tail_outside_compiled_ops_set.count(op);
        });
    if (!can_be_extracted) continue;

    // Collect operands of cluster op that are generated within the cluster.
    // These values should be returned by the cluster.
    cluster_op.walk([&](Operation* op) {
      for (Value operand : op->getOperands()) {
        Operation* operand_op = GetOpOfValue(operand);
        if (operand_op->getParentRegion() == cluster_region)
          cluster_results_set.insert(operand);
      }
    });

    // Remove results of the op to be extracted as they have no uses left in
    // the cluster.
    for (Value result : cluster_op.getResults())
      cluster_results_set.remove(result);
    // Insert all ops including nested ops for checking outputs/side effects.
    cluster_op.walk(
        [&](Operation* op) { tail_outside_compiled_ops_set.insert(op); });

    // Only add top level ops to output vector.
    tail_outside_compiled_ops->push_back(&cluster_op);
  }

  *cluster_results = cluster_results_set.takeVector();
}

// Moves tail outside compiled ops into their own `tf_device.LaunchOp`
// computation after the cluster.
void CreateTailComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
                           llvm::ArrayRef<Operation*> tail_outside_compiled_ops,
                           llvm::StringRef host_device) {
  Block* launch_block = new Block;
  for (Operation* tail_outside_compiled_op : tail_outside_compiled_ops) {
    tail_outside_compiled_op->removeAttr(kXlaOutsideCompilationAttr);
    tail_outside_compiled_op->moveBefore(launch_block, launch_block->begin());
  }

  tf_device::LaunchOp launch = CreateLaunchForBlock(
      builder, cluster, /*before=*/false, launch_block, host_device);

  auto operand_not_in_launch = [&](OpOperand& operand) {
    return !launch.getOperation()->isProperAncestor(operand.getOwner());
  };
  for (auto result : llvm::zip(launch.GetBody().getTerminator()->getOperands(),
                               launch.getResults()))
    std::get<0>(result).replaceUsesWithIf(std::get<1>(result),
                                          operand_not_in_launch);
}

// Updates cluster with updated cluster results after extracting tail outside
// compiled ops.
tf_device::ClusterOp UpdateClusterResults(
    OpBuilder* builder, tf_device::ClusterOp cluster,
    llvm::ArrayRef<Value> new_cluster_results) {
  Operation* old_terminator = cluster.GetBody().getTerminator();
  builder->setInsertionPoint(old_terminator);
  builder->create<tf_device::ReturnOp>(old_terminator->getLoc(),
                                       new_cluster_results);
  old_terminator->erase();

  builder->setInsertionPoint(cluster);
  llvm::SmallVector<Type, 4> new_cluster_result_types;
  new_cluster_result_types.reserve(new_cluster_results.size());
  for (const auto& new_cluster_result : new_cluster_results)
    new_cluster_result_types.push_back(new_cluster_result.getType());

  auto new_cluster = builder->create<tf_device::ClusterOp>(
      cluster.getLoc(), new_cluster_result_types,
      /*operands=*/llvm::ArrayRef<Value>{}, cluster->getAttrs());
  new_cluster.body().takeBody(cluster.body());

  auto operand_not_in_cluster = [&](OpOperand& operand) {
    return !new_cluster.getOperation()->isProperAncestor(operand.getOwner());
  };
  for (auto result :
       llvm::zip(new_cluster.GetBody().getTerminator()->getOperands(),
                 new_cluster.getResults()))
    std::get<0>(result).replaceUsesWithIf(std::get<1>(result),
                                          operand_not_in_cluster);

  cluster.erase();
  return new_cluster;
}

// Extracts and moves outside compiled ops that do not create dependencies in
// the cluster to after the cluster.
mlir::LogicalResult LiftTailOutsideCompiledOps(
    OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis,
    const mlir::TF::RuntimeDevices& devices, std::string host_device,
    tf_device::ClusterOp* cluster, bool* cluster_updated) {
  llvm::SmallVector<Operation*, 4> tail_outside_compiled_ops;
  llvm::SmallVector<Value, 4> cluster_results;
  FindOutsideCompiledOpsAtTailAndClusterResults(side_effect_analysis, *cluster,
                                                &tail_outside_compiled_ops,
                                                &cluster_results);
  if (tail_outside_compiled_ops.empty()) return success();

  if (host_device.empty())
    if (failed(tensorflow::GetHostDeviceOutsideComputation(devices, *cluster,
                                                           &host_device)))
      return failure();

  // Forward all results of cluster first. These results will be remapped once
  // a new cluster is formed.
  cluster->replaceAllUsesWith(
      cluster->GetBody().getTerminator()->getOperands());

  CreateTailComputation(builder, *cluster, tail_outside_compiled_ops,
                        host_device);

  *cluster = UpdateClusterResults(builder, *cluster, cluster_results);

  *cluster_updated = true;
  return success();
}

// Removes cluster results that merely alias values defined outside of the
// cluster; uses of such results are replaced with the aliased values directly.
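//
// For example (illustrative), if the cluster terminator returns a value `%x`
// that is defined above the cluster, the corresponding cluster result merely
// aliases `%x`; its uses are replaced with `%x` directly and the result is
// dropped from the rebuilt cluster.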
void RemoveClusterAliasedOutputs(OpBuilder* builder,
                                 tf_device::ClusterOp cluster) {
  llvm::SmallVector<Value, 4> used_old_cluster_results;
  llvm::SmallVector<Value, 4> new_cluster_results;
  llvm::SmallVector<Type, 4> new_cluster_result_types;
  Operation* cluster_terminator = cluster.GetBody().getTerminator();
  for (auto result :
       llvm::zip(cluster_terminator->getOperands(), cluster.getResults())) {
    Value cluster_terminator_operand = std::get<0>(result);
    if (cluster_terminator_operand.getDefiningOp() &&
        cluster.getOperation()->isProperAncestor(
            cluster_terminator_operand.getDefiningOp())) {
      new_cluster_results.push_back(cluster_terminator_operand);
      new_cluster_result_types.push_back(cluster_terminator_operand.getType());
      used_old_cluster_results.push_back(std::get<1>(result));
    } else {
      std::get<1>(result).replaceAllUsesWith(cluster_terminator_operand);
    }
  }

  if (new_cluster_results.size() == cluster.getNumResults()) return;

  builder->setInsertionPoint(cluster);
  auto new_cluster = builder->create<tf_device::ClusterOp>(
      cluster.getLoc(), new_cluster_result_types,
      /*operands=*/llvm::ArrayRef<Value>{}, cluster->getAttrs());
  new_cluster.body().takeBody(cluster.body());
  new_cluster.GetBody().getTerminator()->setOperands(new_cluster_results);

  for (auto result :
       llvm::zip(used_old_cluster_results, new_cluster.getResults()))
    std::get<0>(result).replaceAllUsesWith(std::get<1>(result));

  cluster.erase();
}

struct TPUExtractHeadTailOutsideCompilationPass
    : public TF::TPUExtractHeadTailOutsideCompilationPassBase<
          TPUExtractHeadTailOutsideCompilationPass> {
  void runOnOperation() override;
};

void TPUExtractHeadTailOutsideCompilationPass::runOnOperation() {
  auto& side_effect_analysis = getAnalysis<TF::SideEffectAnalysis>();
  // Get runtime devices information from the closest parent module.
  auto module = getOperation();
  mlir::TF::RuntimeDevices devices;
  if (failed(tensorflow::GetDevicesFromOp(module, &devices)))
    return signalPassFailure();

  OpBuilder builder(&getContext());
  llvm::SmallVector<tf_device::ClusterOp, 4> clusters;
  module.walk(
      [&](tf_device::ClusterOp cluster) { clusters.push_back(cluster); });

  for (tf_device::ClusterOp cluster : clusters) {
    std::string host_device;
    bool cluster_updated = false;
    if (failed(LiftHeadOutsideCompiledOps(&builder, side_effect_analysis,
                                          devices, cluster, &host_device,
                                          &cluster_updated)) ||
        failed(LiftTailOutsideCompiledOps(&builder, side_effect_analysis,
                                          devices, host_device, &cluster,
                                          &cluster_updated)))
      return signalPassFailure();
    if (cluster_updated) RemoveClusterAliasedOutputs(&builder, cluster);
  }
}

}  // anonymous namespace

std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUExtractHeadTailOutsideCompilationPass() {
  return std::make_unique<TPUExtractHeadTailOutsideCompilationPass>();
}
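
// A minimal usage sketch (assuming a `ModuleOp` named `module`; this snippet
// is not part of this file's API surface):
//
//   mlir::PassManager pm(module.getContext());
//   pm.addPass(mlir::TFTPU::CreateTPUExtractHeadTailOutsideCompilationPass());
//   if (failed(pm.run(module))) { /* handle pass failure */ }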

}  // namespace TFTPU
}  // namespace mlir