1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <tuple>
18 #include <type_traits>
19 #include <utility>
20
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SetVector.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/ADT/StringRef.h"
26 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
27 #include "mlir/IR/Attributes.h" // from @llvm-project
28 #include "mlir/IR/Block.h" // from @llvm-project
29 #include "mlir/IR/Builders.h" // from @llvm-project
30 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
31 #include "mlir/IR/Operation.h" // from @llvm-project
32 #include "mlir/IR/Value.h" // from @llvm-project
33 #include "mlir/IR/Visitors.h" // from @llvm-project
34 #include "mlir/Pass/Pass.h" // from @llvm-project
35 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
36 #include "mlir/Support/LogicalResult.h" // from @llvm-project
37 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
38 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
40 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
41 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
42 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
43 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
44 #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
45 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
46
47 namespace mlir {
48 namespace TFTPU {
49
50 // This pass extracts a CPU computation cluster with `_xla_outside_compilation`
51 // annotation from the head or tail of a TPU cluster.
52
53 namespace {
54
// Name of the attribute that marks an op for outside compilation (i.e. the op
// runs on the host CPU rather than inside the TPU computation).
constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
56
HasOutsideCompilationAttribute(Operation * op)57 bool HasOutsideCompilationAttribute(Operation* op) {
58 return op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr) != nullptr;
59 }
60
61 // Finds op that created a given value. If the value is a BlockArgument, this
62 // returns the owner of the Block.
GetOpOfValue(Value value)63 Operation* GetOpOfValue(Value value) {
64 if (auto block_arg = value.dyn_cast<BlockArgument>())
65 return block_arg.getOwner()->getParentOp();
66
67 return value.getDefiningOp();
68 }
69
70 // Checks if `op` is nested in `block`.
OpInBlock(Operation * op,Block * block)71 bool OpInBlock(Operation* op, Block* block) {
72 Block* op_block = op->getBlock();
73 while (op_block) {
74 if (op_block == block) return true;
75 if (auto* parent_op = op_block->getParentOp()) {
76 op_block = parent_op->getBlock();
77 } else {
78 break;
79 }
80 }
81 return false;
82 }
83
84 // Wraps block in a Launch. External uses of ops in the block will be return
85 // values of the Launch and remapped to the Launch results. If `before` is set
86 // to true, the Launch is created before `op`. Otherwise the Launch is created
87 // after `op`.
CreateLaunchForBlock(OpBuilder * builder,Operation * op,bool before,Block * launch_block,llvm::StringRef host_device)88 tf_device::LaunchOp CreateLaunchForBlock(OpBuilder* builder, Operation* op,
89 bool before, Block* launch_block,
90 llvm::StringRef host_device) {
91 // Find results and result types of ops in block that needs to returned.
92 llvm::SmallVector<Value, 4> launch_results;
93 llvm::SmallVector<Type, 4> launch_result_types;
94 for (Operation& head_outside_compiled_op : *launch_block) {
95 for (Value result : head_outside_compiled_op.getResults()) {
96 bool has_external_uses = false;
97 for (Operation* user : result.getUsers()) {
98 if (OpInBlock(user, launch_block)) continue;
99 has_external_uses = true;
100 break;
101 }
102 if (has_external_uses) {
103 launch_results.push_back(result);
104 launch_result_types.push_back(result.getType());
105 }
106 }
107 }
108
109 before ? builder->setInsertionPoint(op) : builder->setInsertionPointAfter(op);
110 auto launch = builder->create<tf_device::LaunchOp>(
111 op->getLoc(), builder->getStringAttr(host_device), launch_result_types);
112 launch.body().push_back(launch_block);
113
114 builder->setInsertionPointToEnd(&launch.GetBody());
115 builder->create<tf_device::ReturnOp>(op->getLoc(), launch_results);
116
117 return launch;
118 }
119
120 // Checks if an operation is a supported TPU embedding op.
IsEmbeddingOp(Operation * op)121 bool IsEmbeddingOp(Operation* op) {
122 return isa<TF::EnqueueTPUEmbeddingRaggedTensorBatchOp,
123 TF::EnqueueTPUEmbeddingSparseTensorBatchOp,
124 TF::EnqueueTPUEmbeddingArbitraryTensorBatchOp,
125 TF::RecvTPUEmbeddingActivationsOp,
126 TF::SendTPUEmbeddingGradientsOp>(op);
127 }
128
129 // Returns a set of ops that are outside compiled and can be extracted to before
130 // the TPU computation. These ops are either connected to the inputs of the TPU
131 // computation or other ops that can be extracted, and have no operands from
132 // other ops in the TPU computation that cannot be extracted.
// Returns a set of ops that are outside compiled and can be extracted to before
// the TPU computation. These ops are either connected to the inputs of the TPU
// computation or other ops that can be extracted, and have no operands from
// other ops in the TPU computation that cannot be extracted.
llvm::SmallVector<Operation*, 4> FindOutsideCompiledOpsAtHead(
    const TF::SideEffectAnalysis& side_effect_analysis,
    tf_device::ClusterOp cluster) {
  const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
      cluster->getParentOfType<func::FuncOp>());
  Region* cluster_region = &cluster.body();
  // SetVector preserves insertion (i.e. cluster) order in the returned vector.
  llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops;

  auto cluster_ops = cluster.GetBody().without_terminator();
  for (Operation& cluster_op : cluster_ops) {
    if (!HasOutsideCompilationAttribute(&cluster_op)) continue;
    // An outside compiled op can be extracted if its operands are not from
    // other ops in the cluster that cannot be extracted.

    // Check if the side effecting op right before this side effecting op, if
    // it is side effecting, can be head extracted. Because of op ordering due
    // to side effects, if this is not true, this op cannot be head extracted.
    // TODO(lyandy): Remove special handling of embedding ops. Currently the IR
    // is in a topological sort order and depending on that ordering, embedding
    // ops may prevent other ops from being head extracted.
    auto predecessors = analysis.DirectControlPredecessors(&cluster_op);
    if (!predecessors.empty() && !IsEmbeddingOp(&cluster_op)) {
      bool skip = false;
      // Only the closest non-embedding control predecessor matters; iterate
      // in reverse to find it, skipping embedding ops.
      for (Operation* predecessor : llvm::reverse(predecessors)) {
        if (IsEmbeddingOp(predecessor)) continue;
        skip = !head_outside_compiled_ops.contains(predecessor);
        break;
      }
      if (skip) continue;
    }

    // Walk this op (including any nested ops) and reject extraction if any
    // operand is produced inside the cluster by an op that is not itself
    // being extracted.
    auto walk_result = cluster_op.walk([&](Operation* op) {
      for (Value operand : op->getOperands()) {
        Operation* operand_op = GetOpOfValue(operand);
        // Operands produced by already-extractable ops, or originating from
        // `cluster_op` itself (its results or its regions' block arguments),
        // do not block extraction.
        if (head_outside_compiled_ops.count(operand_op) ||
            operand_op == &cluster_op)
          continue;

        if (operand_op->getParentRegion() == cluster_region)
          return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });

    if (!walk_result.wasInterrupted())
      head_outside_compiled_ops.insert(&cluster_op);
  }

  return head_outside_compiled_ops.takeVector();
}
183
184 // Moves head outside compiled ops into its own `tf_device.LaunchOp`
185 // computation before the cluster.
CreateHeadComputation(OpBuilder * builder,tf_device::ClusterOp cluster,llvm::ArrayRef<Operation * > head_outside_compiled_ops,llvm::StringRef host_device)186 void CreateHeadComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
187 llvm::ArrayRef<Operation*> head_outside_compiled_ops,
188 llvm::StringRef host_device) {
189 Block* launch_block = new Block;
190 for (Operation* head_outside_compiled_op : head_outside_compiled_ops) {
191 head_outside_compiled_op->removeAttr(kXlaOutsideCompilationAttr);
192 head_outside_compiled_op->moveBefore(launch_block, launch_block->end());
193 }
194
195 tf_device::LaunchOp launch = CreateLaunchForBlock(
196 builder, cluster, /*before=*/true, launch_block, host_device);
197
198 for (auto result : llvm::zip(launch.GetBody().getTerminator()->getOperands(),
199 launch.getResults()))
200 replaceAllUsesInRegionWith(std::get<0>(result), std::get<1>(result),
201 cluster.body());
202 }
203
204 // Extracts and move outside compiled ops that have no dependencies in the
205 // cluster to before the cluster.
LiftHeadOutsideCompiledOps(OpBuilder * builder,const TF::SideEffectAnalysis & side_effect_analysis,const mlir::TF::RuntimeDevices & devices,tf_device::ClusterOp cluster,std::string * host_device,bool * cluster_updated)206 mlir::LogicalResult LiftHeadOutsideCompiledOps(
207 OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis,
208 const mlir::TF::RuntimeDevices& devices, tf_device::ClusterOp cluster,
209 std::string* host_device, bool* cluster_updated) {
210 llvm::SmallVector<Operation*, 4> head_outside_compiled_ops =
211 FindOutsideCompiledOpsAtHead(side_effect_analysis, cluster);
212 if (head_outside_compiled_ops.empty()) return success();
213 if (failed(tensorflow::GetHostDeviceOutsideComputation(devices, cluster,
214 host_device)))
215 return failure();
216
217 CreateHeadComputation(builder, cluster, head_outside_compiled_ops,
218 *host_device);
219
220 *cluster_updated = true;
221 return success();
222 }
223
224 // Fills `tail_outside_compiled_ops` with ops that are outside compiled and
225 // can be extracted to after the TPU computation, and `cluster_results` with new
226 // results of the cluster. These ops are either connected to the output of the
227 // TPU computation or other ops that can be extracted, and have no results used
228 // by other ops in the TPU computation that cannot be extracted.
FindOutsideCompiledOpsAtTailAndClusterResults(const TF::SideEffectAnalysis & side_effect_analysis,tf_device::ClusterOp cluster,llvm::SmallVectorImpl<Operation * > * tail_outside_compiled_ops,llvm::SmallVectorImpl<Value> * cluster_results)229 void FindOutsideCompiledOpsAtTailAndClusterResults(
230 const TF::SideEffectAnalysis& side_effect_analysis,
231 tf_device::ClusterOp cluster,
232 llvm::SmallVectorImpl<Operation*>* tail_outside_compiled_ops,
233 llvm::SmallVectorImpl<Value>* cluster_results) {
234 const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
235 cluster->getParentOfType<func::FuncOp>());
236 Region* cluster_region = &cluster.body();
237 llvm::SmallSetVector<Operation*, 4> tail_outside_compiled_ops_set;
238 Operation* terminator = cluster.GetBody().getTerminator();
239 llvm::SmallSetVector<Value, 4> cluster_results_set;
240 cluster_results_set.insert(terminator->getOperands().begin(),
241 terminator->getOperands().end());
242
243 auto cluster_ops = llvm::reverse(cluster.GetBody().without_terminator());
244 for (Operation& cluster_op : cluster_ops) {
245 if (!HasOutsideCompilationAttribute(&cluster_op)) continue;
246
247 // Check if the side effecting op right after this side effecting op, if
248 // it is side effecting, can be tail extracted. Because of op ordering due
249 // to side effects, if this is not true, this op cannot be tail extracted.
250 // TODO(lyandy): Remove special handling of embedding ops. Currently the IR
251 // is in a topological sort order and depending on that ordering, embedding
252 // ops may prevent other ops from being tail extracted.
253 auto successors = analysis.DirectControlSuccessors(
254 &cluster_op, [&terminator](Operation* op) { return op != terminator; });
255 if (!successors.empty() && !IsEmbeddingOp(&cluster_op)) {
256 bool skip = false;
257 for (Operation* successor : successors) {
258 if (IsEmbeddingOp(successor)) continue;
259 skip = !tail_outside_compiled_ops_set.contains(successor);
260 break;
261 }
262 if (skip) continue;
263 }
264
265 llvm::SmallVector<int, 4> results_to_forward;
266 bool can_be_extracted =
267 llvm::all_of(cluster_op.getUsers(), [&](Operation* op) {
268 return op == terminator || tail_outside_compiled_ops_set.count(op);
269 });
270 if (!can_be_extracted) continue;
271
272 // Collect operands of cluster op that are generated within the cluster.
273 // These values should be returned by the cluster.
274 cluster_op.walk([&](Operation* op) {
275 for (Value operand : op->getOperands()) {
276 Operation* operand_op = GetOpOfValue(operand);
277 if (operand_op->getParentRegion() == cluster_region)
278 cluster_results_set.insert(operand);
279 }
280 });
281
282 // Remove results of op to be extracted as there are no uses in the cluster.
283 for (Value result : cluster_op.getResults())
284 cluster_results_set.remove(result);
285 // Insert all ops including nested ops for checking outputs/side effects.
286 cluster_op.walk(
287 [&](Operation* op) { tail_outside_compiled_ops_set.insert(op); });
288
289 // Only add top level ops to output vector.
290 tail_outside_compiled_ops->push_back(&cluster_op);
291 }
292
293 *cluster_results = cluster_results_set.takeVector();
294 }
295
296 // Moves tail outside compiled ops into its own `tf_device.LaunchOp`
297 // computation after the cluster.
CreateTailComputation(OpBuilder * builder,tf_device::ClusterOp cluster,llvm::ArrayRef<Operation * > tail_outside_compiled_ops,llvm::StringRef host_device)298 void CreateTailComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
299 llvm::ArrayRef<Operation*> tail_outside_compiled_ops,
300 llvm::StringRef host_device) {
301 Block* launch_block = new Block;
302 for (Operation* tail_outside_compiled_op : tail_outside_compiled_ops) {
303 tail_outside_compiled_op->removeAttr(kXlaOutsideCompilationAttr);
304 tail_outside_compiled_op->moveBefore(launch_block, launch_block->begin());
305 }
306
307 tf_device::LaunchOp launch = CreateLaunchForBlock(
308 builder, cluster, /*before=*/false, launch_block, host_device);
309
310 auto operand_not_in_launch = [&](OpOperand& operand) {
311 return !launch.getOperation()->isProperAncestor(operand.getOwner());
312 };
313 for (auto result : llvm::zip(launch.GetBody().getTerminator()->getOperands(),
314 launch.getResults()))
315 std::get<0>(result).replaceUsesWithIf(std::get<1>(result),
316 operand_not_in_launch);
317 }
318
319 // Updates cluster with updated cluster results after extracting tail outside
320 // compiled ops.
// Updates cluster with updated cluster results after extracting tail outside
// compiled ops. Rebuilds the cluster as a new `tf_device.ClusterOp` returning
// `new_cluster_results`, moves the body over, remaps external uses of the
// returned values, erases the old cluster, and returns the new one.
tf_device::ClusterOp UpdateClusterResults(
    OpBuilder* builder, tf_device::ClusterOp cluster,
    llvm::ArrayRef<Value> new_cluster_results) {
  // Replace the terminator first so the body moved below already returns the
  // new result set.
  Operation* old_terminator = cluster.GetBody().getTerminator();
  builder->setInsertionPoint(old_terminator);
  builder->create<tf_device::ReturnOp>(old_terminator->getLoc(),
                                       new_cluster_results);
  old_terminator->erase();

  builder->setInsertionPoint(cluster);
  llvm::SmallVector<Type, 4> new_cluster_result_types;
  new_cluster_result_types.reserve(new_cluster_results.size());
  for (const auto& new_cluster_result : new_cluster_results)
    new_cluster_result_types.push_back(new_cluster_result.getType());

  // Create the replacement cluster with updated result types, carrying over
  // all attributes, and steal the old cluster's body region.
  auto new_cluster = builder->create<tf_device::ClusterOp>(
      cluster.getLoc(), new_cluster_result_types,
      /*operands=*/llvm::ArrayRef<Value>{}, cluster->getAttrs());
  new_cluster.body().takeBody(cluster.body());

  // Remap uses of the returned values to the new cluster's results, but only
  // for uses that live outside of the new cluster.
  auto operand_not_in_cluster = [&](OpOperand& operand) {
    return !new_cluster.getOperation()->isProperAncestor(operand.getOwner());
  };
  for (auto result :
       llvm::zip(new_cluster.GetBody().getTerminator()->getOperands(),
                 new_cluster.getResults()))
    std::get<0>(result).replaceUsesWithIf(std::get<1>(result),
                                          operand_not_in_cluster);

  cluster.erase();
  return new_cluster;
}
353
354 // Extracts and move outside compiled ops that do not create dependencies in the
355 // cluster to after the cluster.
// Extracts and move outside compiled ops that do not create dependencies in the
// cluster to after the cluster. `host_device` may be empty, in which case it
// is computed here; on success with extraction, `cluster` is updated to point
// at the rebuilt cluster and `cluster_updated` is set to true.
mlir::LogicalResult LiftTailOutsideCompiledOps(
    OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis,
    const mlir::TF::RuntimeDevices& devices, std::string host_device,
    tf_device::ClusterOp* cluster, bool* cluster_updated) {
  llvm::SmallVector<Operation*, 4> tail_outside_compiled_ops;
  llvm::SmallVector<Value, 4> cluster_results;
  FindOutsideCompiledOpsAtTailAndClusterResults(side_effect_analysis, *cluster,
                                                &tail_outside_compiled_ops,
                                                &cluster_results);
  // Nothing extractable; leave the cluster untouched.
  if (tail_outside_compiled_ops.empty()) return success();

  // `host_device` may already have been determined by head extraction (see
  // LiftHeadOutsideCompiledOps); only compute it if still empty.
  if (host_device.empty())
    if (failed(tensorflow::GetHostDeviceOutsideComputation(devices, *cluster,
                                                           &host_device)))
      return failure();

  // Forward all results of cluster first. These results will be remapped once
  // a new cluster is formed.
  cluster->replaceAllUsesWith(
      cluster->GetBody().getTerminator()->getOperands());

  CreateTailComputation(builder, *cluster, tail_outside_compiled_ops,
                        host_device);

  // Rebuild the cluster with the new result set and point the caller at it.
  *cluster = UpdateClusterResults(builder, *cluster, cluster_results);

  *cluster_updated = true;
  return success();
}
385
386 // Removes aliased outputs in cluster from ops outside of cluster.
// Removes aliased outputs in cluster from ops outside of cluster. A cluster
// result whose terminator operand is defined outside the cluster merely
// forwards that value; such results are replaced by the value directly and the
// cluster is rebuilt without them.
void RemoveClusterAliasedOutputs(OpBuilder* builder,
                                 tf_device::ClusterOp cluster) {
  llvm::SmallVector<Value, 4> used_old_cluster_results;
  llvm::SmallVector<Value, 4> new_cluster_results;
  llvm::SmallVector<Type, 4> new_cluster_result_types;
  Operation* cluster_terminator = cluster.GetBody().getTerminator();
  for (auto result :
       llvm::zip(cluster_terminator->getOperands(), cluster.getResults())) {
    Value cluster_terminator_operand = std::get<0>(result);
    // Keep results whose value is defined by an op nested inside the cluster.
    if (cluster_terminator_operand.getDefiningOp() &&
        cluster.getOperation()->isProperAncestor(
            cluster_terminator_operand.getDefiningOp())) {
      new_cluster_results.push_back(cluster_terminator_operand);
      new_cluster_result_types.push_back(cluster_terminator_operand.getType());
      used_old_cluster_results.push_back(std::get<1>(result));
    } else {
      // Aliased result: external uses can consume the outside value directly.
      std::get<1>(result).replaceAllUsesWith(cluster_terminator_operand);
    }
  }

  // No aliased results found; keep the original cluster.
  if (new_cluster_results.size() == cluster.getNumResults()) return;

  // Rebuild the cluster with only the non-aliased results, carrying over all
  // attributes and stealing the old body region.
  builder->setInsertionPoint(cluster);
  auto new_cluster = builder->create<tf_device::ClusterOp>(
      cluster.getLoc(), new_cluster_result_types,
      /*operands=*/llvm::ArrayRef<Value>{}, cluster->getAttrs());
  new_cluster.body().takeBody(cluster.body());
  new_cluster.GetBody().getTerminator()->setOperands(new_cluster_results);

  // Remap uses of the kept old results to the new cluster's results.
  for (auto result :
       llvm::zip(used_old_cluster_results, new_cluster.getResults()))
    std::get<0>(result).replaceAllUsesWith(std::get<1>(result));

  cluster.erase();
}
422
423 // Checks if `type` is allowed for data on TPUs. String and resources cannot be
424 // assigned to TPUs. There are other TF types that are not allowed on TPUs, but
425 // these will be removed by successive passes in TF/XLA bridge phase 2.
TypeValidForTPU(Type type)426 bool TypeValidForTPU(Type type) {
427 Type elem = getElementTypeOrSelf(type);
428 return !elem.isa<TF::ResourceType>() && !elem.isa<TF::StringType>();
429 }
430
431 // Check that cluster results are valid. An result is invalid when it does not
432 // have a valid XLA type.
CheckClusterResults(tf_device::ClusterOp cluster)433 LogicalResult CheckClusterResults(tf_device::ClusterOp cluster) {
434 for (OpResult result : cluster.getResults()) {
435 if (!TypeValidForTPU(result.getType())) {
436 cluster.emitError()
437 << "The TPUExtractHeadTailOutsideCompilation pass produced a TPU "
438 "cluster with a result with a non-XLA type: "
439 << result.getType();
440 return failure();
441 }
442 }
443 return success();
444 }
445
446 // Check the validity of the module, post-pass.
CheckPostconditions(ModuleOp module)447 LogicalResult CheckPostconditions(ModuleOp module) {
448 auto walk_result = module.walk([&](tf_device::ClusterOp cluster) {
449 if (failed(CheckClusterResults(cluster))) return WalkResult::interrupt();
450 return WalkResult::advance();
451 });
452 if (walk_result.wasInterrupted()) return failure();
453 return success();
454 }
455
// Module pass that extracts outside compiled (host CPU) computation from the
// head or tail of TPU clusters into separate `tf_device.LaunchOp`s.
struct TPUExtractHeadTailOutsideCompilationPass
    : public TF::TPUExtractHeadTailOutsideCompilationPassBase<
          TPUExtractHeadTailOutsideCompilationPass> {
  void runOnOperation() override;
};
461
runOnOperation()462 void TPUExtractHeadTailOutsideCompilationPass::runOnOperation() {
463 auto& side_effect_analysis = getAnalysis<TF::SideEffectAnalysis>();
464 // Get runtime devices information from the closest parent module.
465 auto module = getOperation();
466 mlir::TF::RuntimeDevices devices;
467 if (failed(tensorflow::GetDevicesFromOp(module, &devices)))
468 return signalPassFailure();
469
470 OpBuilder builder(&getContext());
471 llvm::SmallVector<tf_device::ClusterOp, 4> clusters;
472 module.walk(
473 [&](tf_device::ClusterOp cluster) { clusters.push_back(cluster); });
474
475 for (tf_device::ClusterOp cluster : clusters) {
476 std::string host_device;
477 bool cluster_updated = false;
478 if (failed(LiftHeadOutsideCompiledOps(&builder, side_effect_analysis,
479 devices, cluster, &host_device,
480 &cluster_updated)) ||
481 failed(LiftTailOutsideCompiledOps(&builder, side_effect_analysis,
482 devices, host_device, &cluster,
483 &cluster_updated)))
484 return signalPassFailure();
485 if (cluster_updated) RemoveClusterAliasedOutputs(&builder, cluster);
486 }
487
488 if (failed(CheckPostconditions(module))) return signalPassFailure();
489 }
490
491 } // anonymous namespace
492
// Factory for the pass that lifts `_xla_outside_compilation` ops from the
// head/tail of TPU clusters into host-side `tf_device.LaunchOp`s.
std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUExtractHeadTailOutsideCompilationPass() {
  return std::make_unique<TPUExtractHeadTailOutsideCompilationPass>();
}
497
498 } // namespace TFTPU
499 } // namespace mlir
500