• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <tuple>
18 #include <type_traits>
19 #include <utility>
20 
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SetVector.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/ADT/StringRef.h"
26 #include "mlir/IR/Attributes.h"  // from @llvm-project
27 #include "mlir/IR/Block.h"  // from @llvm-project
28 #include "mlir/IR/Builders.h"  // from @llvm-project
29 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
30 #include "mlir/IR/Operation.h"  // from @llvm-project
31 #include "mlir/IR/Value.h"  // from @llvm-project
32 #include "mlir/IR/Visitors.h"  // from @llvm-project
33 #include "mlir/Pass/Pass.h"  // from @llvm-project
34 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
35 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
36 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
37 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
40 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
41 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
42 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
43 #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
44 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
45 
46 namespace mlir {
47 namespace TFTPU {
48 
49 // This pass extracts a CPU computation cluster with `_xla_outside_compilation`
50 // annotation from the head or tail of a TPU cluster.
51 
52 namespace {
53 
54 constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
55 
HasOutsideCompilationAttribute(Operation * op)56 bool HasOutsideCompilationAttribute(Operation* op) {
57   return op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr) != nullptr;
58 }
59 
60 // Finds op that created a given value. If the value is a BlockArgument, this
61 // returns the owner of the Block.
GetOpOfValue(Value value)62 Operation* GetOpOfValue(Value value) {
63   if (auto block_arg = value.dyn_cast<BlockArgument>())
64     return block_arg.getOwner()->getParentOp();
65 
66   return value.getDefiningOp();
67 }
68 
69 // Checks if `op` is nested in `block`.
OpInBlock(Operation * op,Block * block)70 bool OpInBlock(Operation* op, Block* block) {
71   Block* op_block = op->getBlock();
72   while (op_block) {
73     if (op_block == block) return true;
74     if (auto* parent_op = op_block->getParentOp()) {
75       op_block = parent_op->getBlock();
76     } else {
77       break;
78     }
79   }
80   return false;
81 }
82 
83 // Wraps block in a Launch. External uses of ops in the block will be return
84 // values of the Launch and remapped to the Launch results. If `before` is set
85 // to true, the Launch is created before `op`. Otherwise the Launch is created
86 // after `op`.
CreateLaunchForBlock(OpBuilder * builder,Operation * op,bool before,Block * launch_block,llvm::StringRef host_device)87 tf_device::LaunchOp CreateLaunchForBlock(OpBuilder* builder, Operation* op,
88                                          bool before, Block* launch_block,
89                                          llvm::StringRef host_device) {
90   // Find results and result types of ops in block that needs to returned.
91   llvm::SmallVector<Value, 4> launch_results;
92   llvm::SmallVector<Type, 4> launch_result_types;
93   for (Operation& head_outside_compiled_op : *launch_block) {
94     for (Value result : head_outside_compiled_op.getResults()) {
95       bool has_external_uses = false;
96       for (Operation* user : result.getUsers()) {
97         if (OpInBlock(user, launch_block)) continue;
98         has_external_uses = true;
99         break;
100       }
101       if (has_external_uses) {
102         launch_results.push_back(result);
103         launch_result_types.push_back(result.getType());
104       }
105     }
106   }
107 
108   before ? builder->setInsertionPoint(op) : builder->setInsertionPointAfter(op);
109   auto launch = builder->create<tf_device::LaunchOp>(
110       op->getLoc(), builder->getStringAttr(host_device), launch_result_types);
111   launch.body().push_back(launch_block);
112 
113   builder->setInsertionPointToEnd(&launch.GetBody());
114   builder->create<tf_device::ReturnOp>(op->getLoc(), launch_results);
115 
116   return launch;
117 }
118 
119 // Checks if an operation is a supported TPU embedding op.
IsEmbeddingOp(Operation * op)120 bool IsEmbeddingOp(Operation* op) {
121   return isa<TF::EnqueueTPUEmbeddingRaggedTensorBatchOp,
122              TF::EnqueueTPUEmbeddingSparseTensorBatchOp,
123              TF::RecvTPUEmbeddingActivationsOp,
124              TF::SendTPUEmbeddingGradientsOp>(op);
125 }
126 
127 // Returns a set of ops that are outside compiled and can be extracted to before
128 // the TPU computation. These ops are either connected to the inputs of the TPU
129 // computation or other ops that can be extracted, and have no operands from
130 // other ops in the TPU computation that cannot be extracted.
FindOutsideCompiledOpsAtHead(const TF::SideEffectAnalysis & side_effect_analysis,tf_device::ClusterOp cluster)131 llvm::SmallVector<Operation*, 4> FindOutsideCompiledOpsAtHead(
132     const TF::SideEffectAnalysis& side_effect_analysis,
133     tf_device::ClusterOp cluster) {
134   const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
135       cluster->getParentOfType<FuncOp>());
136   Region* cluster_region = &cluster.body();
137   llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops;
138 
139   auto cluster_ops = cluster.GetBody().without_terminator();
140   for (Operation& cluster_op : cluster_ops) {
141     if (!HasOutsideCompilationAttribute(&cluster_op)) continue;
142     // An outside compiled op can be extracted if its operands are not from
143     // other ops in the cluster that cannot be extracted.
144 
145     // Check if the side effecting op right before this side effecting op, if
146     // it is side effecting, can be head extracted. Because of op ordering due
147     // to side effects, if this is not true, this op cannot be head extracted.
148     // TODO(lyandy): Remove special handling of embedding ops. Currently the IR
149     // is in a topological sort order and depending on that ordering, embedding
150     // ops may prevent other ops from being head extracted.
151     auto predecessors = analysis.DirectControlPredecessors(&cluster_op);
152     if (!predecessors.empty() && !IsEmbeddingOp(&cluster_op)) {
153       bool skip = false;
154       for (Operation* predecessor : llvm::reverse(predecessors)) {
155         if (IsEmbeddingOp(predecessor)) continue;
156         skip = !head_outside_compiled_ops.contains(predecessor);
157         break;
158       }
159       if (skip) continue;
160     }
161 
162     auto walk_result = cluster_op.walk([&](Operation* op) {
163       for (Value operand : op->getOperands()) {
164         Operation* operand_op = GetOpOfValue(operand);
165         if (head_outside_compiled_ops.count(operand_op) ||
166             operand_op == &cluster_op)
167           continue;
168 
169         if (operand_op->getParentRegion() == cluster_region)
170           return WalkResult::interrupt();
171       }
172       return WalkResult::advance();
173     });
174 
175     if (!walk_result.wasInterrupted())
176       head_outside_compiled_ops.insert(&cluster_op);
177   }
178 
179   return head_outside_compiled_ops.takeVector();
180 }
181 
182 // Moves head outside compiled ops into its own `tf_device.LaunchOp`
183 // computation before the cluster.
CreateHeadComputation(OpBuilder * builder,tf_device::ClusterOp cluster,llvm::ArrayRef<Operation * > head_outside_compiled_ops,llvm::StringRef host_device)184 void CreateHeadComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
185                            llvm::ArrayRef<Operation*> head_outside_compiled_ops,
186                            llvm::StringRef host_device) {
187   Block* launch_block = new Block;
188   for (Operation* head_outside_compiled_op : head_outside_compiled_ops) {
189     head_outside_compiled_op->removeAttr(kXlaOutsideCompilationAttr);
190     head_outside_compiled_op->moveBefore(launch_block, launch_block->end());
191   }
192 
193   tf_device::LaunchOp launch = CreateLaunchForBlock(
194       builder, cluster, /*before=*/true, launch_block, host_device);
195 
196   for (auto result : llvm::zip(launch.GetBody().getTerminator()->getOperands(),
197                                launch.getResults()))
198     replaceAllUsesInRegionWith(std::get<0>(result), std::get<1>(result),
199                                cluster.body());
200 }
201 
202 // Extracts and move outside compiled ops that have no dependencies in the
203 // cluster to before the cluster.
LiftHeadOutsideCompiledOps(OpBuilder * builder,const TF::SideEffectAnalysis & side_effect_analysis,const mlir::TF::RuntimeDevices & devices,tf_device::ClusterOp cluster,std::string * host_device,bool * cluster_updated)204 mlir::LogicalResult LiftHeadOutsideCompiledOps(
205     OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis,
206     const mlir::TF::RuntimeDevices& devices, tf_device::ClusterOp cluster,
207     std::string* host_device, bool* cluster_updated) {
208   llvm::SmallVector<Operation*, 4> head_outside_compiled_ops =
209       FindOutsideCompiledOpsAtHead(side_effect_analysis, cluster);
210   if (head_outside_compiled_ops.empty()) return success();
211   if (failed(tensorflow::GetHostDeviceOutsideComputation(devices, cluster,
212                                                          host_device)))
213     return failure();
214 
215   CreateHeadComputation(builder, cluster, head_outside_compiled_ops,
216                         *host_device);
217 
218   *cluster_updated = true;
219   return success();
220 }
221 
222 // Fills `tail_outside_compiled_ops` with ops that are outside compiled and
223 // can be extracted to after the TPU computation, and `cluster_results` with new
224 // results of the cluster. These ops are either connected to the output of the
225 // TPU computation or other ops that can be extracted, and have no results used
226 // by other ops in the TPU computation that cannot be extracted.
FindOutsideCompiledOpsAtTailAndClusterResults(const TF::SideEffectAnalysis & side_effect_analysis,tf_device::ClusterOp cluster,llvm::SmallVectorImpl<Operation * > * tail_outside_compiled_ops,llvm::SmallVectorImpl<Value> * cluster_results)227 void FindOutsideCompiledOpsAtTailAndClusterResults(
228     const TF::SideEffectAnalysis& side_effect_analysis,
229     tf_device::ClusterOp cluster,
230     llvm::SmallVectorImpl<Operation*>* tail_outside_compiled_ops,
231     llvm::SmallVectorImpl<Value>* cluster_results) {
232   const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
233       cluster->getParentOfType<FuncOp>());
234   Region* cluster_region = &cluster.body();
235   llvm::SmallSetVector<Operation*, 4> tail_outside_compiled_ops_set;
236   Operation* terminator = cluster.GetBody().getTerminator();
237   llvm::SmallSetVector<Value, 4> cluster_results_set;
238   cluster_results_set.insert(terminator->getOperands().begin(),
239                              terminator->getOperands().end());
240 
241   auto cluster_ops = llvm::reverse(cluster.GetBody().without_terminator());
242   for (Operation& cluster_op : cluster_ops) {
243     if (!HasOutsideCompilationAttribute(&cluster_op)) continue;
244 
245     // Check if the side effecting op right after this side effecting op, if
246     // it is side effecting, can be tail extracted. Because of op ordering due
247     // to side effects, if this is not true, this op cannot be tail extracted.
248     // TODO(lyandy): Remove special handling of embedding ops. Currently the IR
249     // is in a topological sort order and depending on that ordering, embedding
250     // ops may prevent other ops from being tail extracted.
251     auto successors = analysis.DirectControlSuccessors(
252         &cluster_op, [&terminator](Operation* op) { return op != terminator; });
253     if (!successors.empty() && !IsEmbeddingOp(&cluster_op)) {
254       bool skip = false;
255       for (Operation* successor : successors) {
256         if (IsEmbeddingOp(successor)) continue;
257         skip = !tail_outside_compiled_ops_set.contains(successor);
258         break;
259       }
260       if (skip) continue;
261     }
262 
263     llvm::SmallVector<int, 4> results_to_forward;
264     bool can_be_extracted =
265         llvm::all_of(cluster_op.getUsers(), [&](Operation* op) {
266           return op == terminator || tail_outside_compiled_ops_set.count(op);
267         });
268     if (!can_be_extracted) continue;
269 
270     // Collect operands of cluster op that are generated within the cluster.
271     // These values should be returned by the cluster.
272     cluster_op.walk([&](Operation* op) {
273       for (Value operand : op->getOperands()) {
274         Operation* operand_op = GetOpOfValue(operand);
275         if (operand_op->getParentRegion() == cluster_region)
276           cluster_results_set.insert(operand);
277       }
278     });
279 
280     // Remove results of op to be extracted as there are no uses in the cluster.
281     for (Value result : cluster_op.getResults())
282       cluster_results_set.remove(result);
283     // Insert all ops including nested ops for checking outputs/side effects.
284     cluster_op.walk(
285         [&](Operation* op) { tail_outside_compiled_ops_set.insert(op); });
286 
287     // Only add top level ops to output vector.
288     tail_outside_compiled_ops->push_back(&cluster_op);
289   }
290 
291   *cluster_results = cluster_results_set.takeVector();
292 }
293 
294 // Moves tail outside compiled ops into its own `tf_device.LaunchOp`
295 // computation after the cluster.
CreateTailComputation(OpBuilder * builder,tf_device::ClusterOp cluster,llvm::ArrayRef<Operation * > tail_outside_compiled_ops,llvm::StringRef host_device)296 void CreateTailComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
297                            llvm::ArrayRef<Operation*> tail_outside_compiled_ops,
298                            llvm::StringRef host_device) {
299   Block* launch_block = new Block;
300   for (Operation* tail_outside_compiled_op : tail_outside_compiled_ops) {
301     tail_outside_compiled_op->removeAttr(kXlaOutsideCompilationAttr);
302     tail_outside_compiled_op->moveBefore(launch_block, launch_block->begin());
303   }
304 
305   tf_device::LaunchOp launch = CreateLaunchForBlock(
306       builder, cluster, /*before=*/false, launch_block, host_device);
307 
308   auto operand_not_in_launch = [&](OpOperand& operand) {
309     return !launch.getOperation()->isProperAncestor(operand.getOwner());
310   };
311   for (auto result : llvm::zip(launch.GetBody().getTerminator()->getOperands(),
312                                launch.getResults()))
313     std::get<0>(result).replaceUsesWithIf(std::get<1>(result),
314                                           operand_not_in_launch);
315 }
316 
317 // Updates cluster with updated cluster results after extracting tail outside
318 // compiled ops.
UpdateClusterResults(OpBuilder * builder,tf_device::ClusterOp cluster,llvm::ArrayRef<Value> new_cluster_results)319 tf_device::ClusterOp UpdateClusterResults(
320     OpBuilder* builder, tf_device::ClusterOp cluster,
321     llvm::ArrayRef<Value> new_cluster_results) {
322   Operation* old_terminator = cluster.GetBody().getTerminator();
323   builder->setInsertionPoint(old_terminator);
324   builder->create<tf_device::ReturnOp>(old_terminator->getLoc(),
325                                        new_cluster_results);
326   old_terminator->erase();
327 
328   builder->setInsertionPoint(cluster);
329   llvm::SmallVector<Type, 4> new_cluster_result_types;
330   new_cluster_result_types.reserve(new_cluster_results.size());
331   for (const auto& new_cluster_result : new_cluster_results)
332     new_cluster_result_types.push_back(new_cluster_result.getType());
333 
334   auto new_cluster = builder->create<tf_device::ClusterOp>(
335       cluster.getLoc(), new_cluster_result_types,
336       /*operands=*/llvm::ArrayRef<Value>{}, cluster->getAttrs());
337   new_cluster.body().takeBody(cluster.body());
338 
339   auto operand_not_in_cluster = [&](OpOperand& operand) {
340     return !new_cluster.getOperation()->isProperAncestor(operand.getOwner());
341   };
342   for (auto result :
343        llvm::zip(new_cluster.GetBody().getTerminator()->getOperands(),
344                  new_cluster.getResults()))
345     std::get<0>(result).replaceUsesWithIf(std::get<1>(result),
346                                           operand_not_in_cluster);
347 
348   cluster.erase();
349   return new_cluster;
350 }
351 
352 // Extracts and move outside compiled ops that do not create dependencies in the
353 // cluster to after the cluster.
LiftTailOutsideCompiledOps(OpBuilder * builder,const TF::SideEffectAnalysis & side_effect_analysis,const mlir::TF::RuntimeDevices & devices,std::string host_device,tf_device::ClusterOp * cluster,bool * cluster_updated)354 mlir::LogicalResult LiftTailOutsideCompiledOps(
355     OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis,
356     const mlir::TF::RuntimeDevices& devices, std::string host_device,
357     tf_device::ClusterOp* cluster, bool* cluster_updated) {
358   llvm::SmallVector<Operation*, 4> tail_outside_compiled_ops;
359   llvm::SmallVector<Value, 4> cluster_results;
360   FindOutsideCompiledOpsAtTailAndClusterResults(side_effect_analysis, *cluster,
361                                                 &tail_outside_compiled_ops,
362                                                 &cluster_results);
363   if (tail_outside_compiled_ops.empty()) return success();
364 
365   if (host_device.empty())
366     if (failed(tensorflow::GetHostDeviceOutsideComputation(devices, *cluster,
367                                                            &host_device)))
368       return failure();
369 
370   // Forward all results of cluster first. These results will be remapped once
371   // a new cluster is formed.
372   cluster->replaceAllUsesWith(
373       cluster->GetBody().getTerminator()->getOperands());
374 
375   CreateTailComputation(builder, *cluster, tail_outside_compiled_ops,
376                         host_device);
377 
378   *cluster = UpdateClusterResults(builder, *cluster, cluster_results);
379 
380   *cluster_updated = true;
381   return success();
382 }
383 
384 // Removes aliased outputs in cluster from ops outside of cluster.
RemoveClusterAliasedOutputs(OpBuilder * builder,tf_device::ClusterOp cluster)385 void RemoveClusterAliasedOutputs(OpBuilder* builder,
386                                  tf_device::ClusterOp cluster) {
387   llvm::SmallVector<Value, 4> used_old_cluster_results;
388   llvm::SmallVector<Value, 4> new_cluster_results;
389   llvm::SmallVector<Type, 4> new_cluster_result_types;
390   Operation* cluster_terminator = cluster.GetBody().getTerminator();
391   for (auto result :
392        llvm::zip(cluster_terminator->getOperands(), cluster.getResults())) {
393     Value cluster_terminator_operand = std::get<0>(result);
394     if (cluster_terminator_operand.getDefiningOp() &&
395         cluster.getOperation()->isProperAncestor(
396             cluster_terminator_operand.getDefiningOp())) {
397       new_cluster_results.push_back(cluster_terminator_operand);
398       new_cluster_result_types.push_back(cluster_terminator_operand.getType());
399       used_old_cluster_results.push_back(std::get<1>(result));
400     } else {
401       std::get<1>(result).replaceAllUsesWith(cluster_terminator_operand);
402     }
403   }
404 
405   if (new_cluster_results.size() == cluster.getNumResults()) return;
406 
407   builder->setInsertionPoint(cluster);
408   auto new_cluster = builder->create<tf_device::ClusterOp>(
409       cluster.getLoc(), new_cluster_result_types,
410       /*operands=*/llvm::ArrayRef<Value>{}, cluster->getAttrs());
411   new_cluster.body().takeBody(cluster.body());
412   new_cluster.GetBody().getTerminator()->setOperands(new_cluster_results);
413 
414   for (auto result :
415        llvm::zip(used_old_cluster_results, new_cluster.getResults()))
416     std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
417 
418   cluster.erase();
419 }
420 
421 struct TPUExtractHeadTailOutsideCompilationPass
422     : public TF::TPUExtractHeadTailOutsideCompilationPassBase<
423           TPUExtractHeadTailOutsideCompilationPass> {
424   void runOnOperation() override;
425 };
426 
runOnOperation()427 void TPUExtractHeadTailOutsideCompilationPass::runOnOperation() {
428   auto& side_effect_analysis = getAnalysis<TF::SideEffectAnalysis>();
429   // Get runtime devices information from the closest parent module.
430   auto module = getOperation();
431   mlir::TF::RuntimeDevices devices;
432   if (failed(tensorflow::GetDevicesFromOp(module, &devices)))
433     return signalPassFailure();
434 
435   OpBuilder builder(&getContext());
436   llvm::SmallVector<tf_device::ClusterOp, 4> clusters;
437   module.walk(
438       [&](tf_device::ClusterOp cluster) { clusters.push_back(cluster); });
439 
440   for (tf_device::ClusterOp cluster : clusters) {
441     std::string host_device;
442     bool cluster_updated = false;
443     if (failed(LiftHeadOutsideCompiledOps(&builder, side_effect_analysis,
444                                           devices, cluster, &host_device,
445                                           &cluster_updated)) ||
446         failed(LiftTailOutsideCompiledOps(&builder, side_effect_analysis,
447                                           devices, host_device, &cluster,
448                                           &cluster_updated)))
449       return signalPassFailure();
450     if (cluster_updated) RemoveClusterAliasedOutputs(&builder, cluster);
451   }
452 }
453 
454 }  // anonymous namespace
455 
456 std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUExtractHeadTailOutsideCompilationPass()457 CreateTPUExtractHeadTailOutsideCompilationPass() {
458   return std::make_unique<TPUExtractHeadTailOutsideCompilationPass>();
459 }
460 
461 }  // namespace TFTPU
462 }  // namespace mlir
463