• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <tuple>
18 #include <type_traits>
19 #include <utility>
20 
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SetVector.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/ADT/StringRef.h"
26 #include "mlir/IR/Attributes.h"  // from @llvm-project
27 #include "mlir/IR/Block.h"  // from @llvm-project
28 #include "mlir/IR/Builders.h"  // from @llvm-project
29 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
30 #include "mlir/IR/Operation.h"  // from @llvm-project
31 #include "mlir/IR/Value.h"  // from @llvm-project
32 #include "mlir/IR/Visitors.h"  // from @llvm-project
33 #include "mlir/Pass/Pass.h"  // from @llvm-project
34 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
35 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
36 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
37 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
40 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
41 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
42 #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
43 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
44 
45 namespace mlir {
46 namespace TFTPU {
47 
48 // This pass extracts a CPU computation cluster with `_xla_outside_compilation`
49 // annotation from the head or tail of a TPU cluster.
50 
51 namespace {
52 
// Attribute marking an op as outside compiled (to be run on the host CPU
// rather than inside the TPU computation).
constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
54 
HasOutsideCompilationAttribute(Operation * op)55 bool HasOutsideCompilationAttribute(Operation* op) {
56   return op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr) != nullptr;
57 }
58 
59 // Finds op that created a given value. If the value is a BlockArgument, this
60 // returns the owner of the Block.
GetOpOfValue(Value value)61 Operation* GetOpOfValue(Value value) {
62   if (auto block_arg = value.dyn_cast<BlockArgument>())
63     return block_arg.getOwner()->getParentOp();
64 
65   return value.getDefiningOp();
66 }
67 
68 // Checks if `op` is nested in `block`.
OpInBlock(Operation * op,Block * block)69 bool OpInBlock(Operation* op, Block* block) {
70   Block* op_block = op->getBlock();
71   while (op_block) {
72     if (op_block == block) return true;
73     if (auto* parent_op = op_block->getParentOp()) {
74       op_block = parent_op->getBlock();
75     } else {
76       break;
77     }
78   }
79   return false;
80 }
81 
82 // Wraps block in a Launch. External uses of ops in the block will be return
83 // values of the Launch and remapped to the Launch results. If `before` is set
84 // to true, the Launch is created before `op`. Otherwise the Launch is created
85 // after `op`.
CreateLaunchForBlock(OpBuilder * builder,Operation * op,bool before,Block * launch_block,llvm::StringRef host_device)86 tf_device::LaunchOp CreateLaunchForBlock(OpBuilder* builder, Operation* op,
87                                          bool before, Block* launch_block,
88                                          llvm::StringRef host_device) {
89   // Find results and result types of ops in block that needs to returned.
90   llvm::SmallVector<Value, 4> launch_results;
91   llvm::SmallVector<Type, 4> launch_result_types;
92   for (Operation& head_outside_compiled_op : *launch_block) {
93     for (Value result : head_outside_compiled_op.getResults()) {
94       bool has_external_uses = false;
95       for (Operation* user : result.getUsers()) {
96         if (OpInBlock(user, launch_block)) continue;
97         has_external_uses = true;
98         break;
99       }
100       if (has_external_uses) {
101         launch_results.push_back(result);
102         launch_result_types.push_back(result.getType());
103       }
104     }
105   }
106 
107   before ? builder->setInsertionPoint(op) : builder->setInsertionPointAfter(op);
108   auto launch = builder->create<tf_device::LaunchOp>(
109       op->getLoc(), builder->getStringAttr(host_device), launch_result_types);
110   launch.body().push_back(launch_block);
111 
112   builder->setInsertionPointToEnd(&launch.GetBody());
113   builder->create<tf_device::ReturnOp>(op->getLoc(), launch_results);
114 
115   return launch;
116 }
117 
118 // Checks if an operation is a supported TPU embedding op.
IsEmbeddingOp(Operation * op)119 bool IsEmbeddingOp(Operation* op) {
120   return isa<TF::EnqueueTPUEmbeddingRaggedTensorBatchOp,
121              TF::EnqueueTPUEmbeddingSparseTensorBatchOp,
122              TF::RecvTPUEmbeddingActivationsOp,
123              TF::SendTPUEmbeddingGradientsOp>(op);
124 }
125 
// Returns a set of ops that are outside compiled and can be extracted to before
// the TPU computation. These ops are either connected to the inputs of the TPU
// computation or other ops that can be extracted, and have no operands from
// other ops in the TPU computation that cannot be extracted.
llvm::SmallVector<Operation*, 4> FindOutsideCompiledOpsAtHead(
    const TF::SideEffectAnalysis& side_effect_analysis,
    tf_device::ClusterOp cluster) {
  const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
      cluster->getParentOfType<FuncOp>());
  Region* cluster_region = &cluster.body();
  // Ops determined so far to be head extractable; also used for O(1)
  // membership checks while scanning the remaining ops.
  llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops;

  auto cluster_ops = cluster.GetBody().without_terminator();
  for (Operation& cluster_op : cluster_ops) {
    if (!HasOutsideCompilationAttribute(&cluster_op)) continue;
    // An outside compiled op can be extracted if its operands are not from
    // other ops in the cluster that cannot be extracted.

    // Check if the side effecting op right before this side effecting op, if
    // it is side effecting, can be head extracted. Because of op ordering due
    // to side effects, if this is not true, this op cannot be head extracted.
    // TODO(lyandy): Remove special handling of embedding ops. Currently the IR
    // is in a topological sort order and depending on that ordering, embedding
    // ops may prevent other ops from being head extracted.
    auto predecessors = analysis.DirectControlPredecessors(&cluster_op);
    if (!predecessors.empty() && !IsEmbeddingOp(&cluster_op)) {
      bool skip = false;
      // Only the closest (last) non-embedding control predecessor matters.
      for (Operation* predecessor : llvm::reverse(predecessors)) {
        if (IsEmbeddingOp(predecessor)) continue;
        skip = !head_outside_compiled_ops.contains(predecessor);
        break;
      }
      if (skip) continue;
    }

    // Walk the op, including nested regions: every operand must either come
    // from outside the cluster or from an op already marked extractable.
    auto walk_result = cluster_op.walk([&](Operation* op) {
      for (Value operand : op->getOperands()) {
        Operation* operand_op = GetOpOfValue(operand);
        if (head_outside_compiled_ops.count(operand_op)) continue;

        if (operand_op->getParentRegion() == cluster_region)
          return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });

    if (!walk_result.wasInterrupted())
      head_outside_compiled_ops.insert(&cluster_op);
  }

  return head_outside_compiled_ops.takeVector();
}
178 
// Moves head outside compiled ops into its own `tf_device.LaunchOp`
// computation before the cluster.
void CreateHeadComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
                           llvm::ArrayRef<Operation*> head_outside_compiled_ops,
                           llvm::StringRef host_device) {
  // Ownership of this block is transferred to the launch created below.
  Block* launch_block = new Block;
  for (Operation* head_outside_compiled_op : head_outside_compiled_ops) {
    head_outside_compiled_op->removeAttr(kXlaOutsideCompilationAttr);
    // Appending each op preserves the order of `head_outside_compiled_ops`.
    head_outside_compiled_op->moveBefore(launch_block, launch_block->end());
  }

  tf_device::LaunchOp launch = CreateLaunchForBlock(
      builder, cluster, /*before=*/true, launch_block, host_device);

  // Remap uses within the cluster of values now produced inside the launch to
  // the corresponding launch results.
  for (auto result : llvm::zip(launch.GetBody().getTerminator()->getOperands(),
                               launch.getResults()))
    replaceAllUsesInRegionWith(std::get<0>(result), std::get<1>(result),
                               cluster.body());
}
198 
// Extracts and move outside compiled ops that have no dependencies in the
// cluster to before the cluster. When extraction happens, `host_device` is
// filled with the chosen host device and `cluster_updated` is set to true;
// when there is nothing to extract, both are left untouched.
mlir::LogicalResult LiftHeadOutsideCompiledOps(
    OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis,
    const mlir::TF::RuntimeDevices& devices, tf_device::ClusterOp cluster,
    std::string* host_device, bool* cluster_updated) {
  llvm::SmallVector<Operation*, 4> head_outside_compiled_ops =
      FindOutsideCompiledOpsAtHead(side_effect_analysis, cluster);
  // Nothing to extract is not an error.
  if (head_outside_compiled_ops.empty()) return success();
  // Determine the host device to run the extracted computation on.
  if (failed(tensorflow::GetHostDeviceOutsideComputation(devices, cluster,
                                                         host_device)))
    return failure();

  CreateHeadComputation(builder, cluster, head_outside_compiled_ops,
                        *host_device);

  *cluster_updated = true;
  return success();
}
218 
219 // Fills `tail_outside_compiled_ops` with ops that are outside compiled and
220 // can be extracted to after the TPU computation, and `cluster_results` with new
221 // results of the cluster. These ops are either connected to the output of the
222 // TPU computation or other ops that can be extracted, and have no results used
223 // by other ops in the TPU computation that cannot be extracted.
FindOutsideCompiledOpsAtTailAndClusterResults(const TF::SideEffectAnalysis & side_effect_analysis,tf_device::ClusterOp cluster,llvm::SmallVectorImpl<Operation * > * tail_outside_compiled_ops,llvm::SmallVectorImpl<Value> * cluster_results)224 void FindOutsideCompiledOpsAtTailAndClusterResults(
225     const TF::SideEffectAnalysis& side_effect_analysis,
226     tf_device::ClusterOp cluster,
227     llvm::SmallVectorImpl<Operation*>* tail_outside_compiled_ops,
228     llvm::SmallVectorImpl<Value>* cluster_results) {
229   const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
230       cluster->getParentOfType<FuncOp>());
231   Region* cluster_region = &cluster.body();
232   llvm::SmallSetVector<Operation*, 4> tail_outside_compiled_ops_set;
233   Operation* terminator = cluster.GetBody().getTerminator();
234   llvm::SmallSetVector<Value, 4> cluster_results_set;
235   cluster_results_set.insert(terminator->getOperands().begin(),
236                              terminator->getOperands().end());
237 
238   auto cluster_ops = llvm::reverse(cluster.GetBody().without_terminator());
239   for (Operation& cluster_op : cluster_ops) {
240     if (!HasOutsideCompilationAttribute(&cluster_op)) continue;
241 
242     // Check if the side effecting op right after this side effecting op, if
243     // it is side effecting, can be tail extracted. Because of op ordering due
244     // to side effects, if this is not true, this op cannot be tail extracted.
245     // TODO(lyandy): Remove special handling of embedding ops. Currently the IR
246     // is in a topological sort order and depending on that ordering, embedding
247     // ops may prevent other ops from being tail extracted.
248     auto successors = analysis.DirectControlSuccessors(
249         &cluster_op, [&terminator](Operation* op) { return op != terminator; });
250     if (!successors.empty() && !IsEmbeddingOp(&cluster_op)) {
251       bool skip = false;
252       for (Operation* successor : successors) {
253         if (IsEmbeddingOp(successor)) continue;
254         skip = !tail_outside_compiled_ops_set.contains(successor);
255         break;
256       }
257       if (skip) continue;
258     }
259 
260     llvm::SmallVector<int, 4> results_to_forward;
261     bool can_be_extracted =
262         llvm::all_of(cluster_op.getUsers(), [&](Operation* op) {
263           return op == terminator || tail_outside_compiled_ops_set.count(op);
264         });
265     if (!can_be_extracted) continue;
266 
267     // Collect operands of cluster op that are generated within the cluster.
268     // These values should be returned by the cluster.
269     cluster_op.walk([&](Operation* op) {
270       for (Value operand : op->getOperands()) {
271         Operation* operand_op = GetOpOfValue(operand);
272         if (operand_op->getParentRegion() == cluster_region)
273           cluster_results_set.insert(operand);
274       }
275     });
276 
277     // Remove results of op to be extracted as there are no uses in the cluster.
278     for (Value result : cluster_op.getResults())
279       cluster_results_set.remove(result);
280     tail_outside_compiled_ops_set.insert(&cluster_op);
281   }
282 
283   *tail_outside_compiled_ops = tail_outside_compiled_ops_set.takeVector();
284   *cluster_results = cluster_results_set.takeVector();
285 }
286 
// Moves tail outside compiled ops into its own `tf_device.LaunchOp`
// computation after the cluster.
void CreateTailComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
                           llvm::ArrayRef<Operation*> tail_outside_compiled_ops,
                           llvm::StringRef host_device) {
  // Ownership of this block is transferred to the launch created below.
  Block* launch_block = new Block;
  for (Operation* tail_outside_compiled_op : tail_outside_compiled_ops) {
    tail_outside_compiled_op->removeAttr(kXlaOutsideCompilationAttr);
    // The caller collects these ops in reverse program order, so inserting
    // each at the front of the block restores the original order.
    tail_outside_compiled_op->moveBefore(launch_block, launch_block->begin());
  }

  tf_device::LaunchOp launch = CreateLaunchForBlock(
      builder, cluster, /*before=*/false, launch_block, host_device);

  // Remap uses outside of the launch from values now defined in the launch
  // body to the corresponding launch results.
  auto operand_not_in_launch = [&](OpOperand& operand) {
    return !launch.getOperation()->isProperAncestor(operand.getOwner());
  };
  for (auto result : llvm::zip(launch.GetBody().getTerminator()->getOperands(),
                               launch.getResults()))
    std::get<0>(result).replaceUsesWithIf(std::get<1>(result),
                                          operand_not_in_launch);
}
309 
// Updates cluster with updated cluster results after extracting tail outside
// compiled ops. Returns the replacement cluster; the original is erased.
tf_device::ClusterOp UpdateClusterResults(
    OpBuilder* builder, tf_device::ClusterOp cluster,
    llvm::ArrayRef<Value> new_cluster_results) {
  // Rebuild the terminator so it returns only `new_cluster_results`.
  Operation* old_terminator = cluster.GetBody().getTerminator();
  builder->setInsertionPoint(old_terminator);
  builder->create<tf_device::ReturnOp>(old_terminator->getLoc(),
                                       new_cluster_results);
  old_terminator->erase();

  // Result types of an op are fixed at creation, so build a new cluster with
  // the updated result types and move the old body into it.
  builder->setInsertionPoint(cluster);
  llvm::SmallVector<Type, 4> new_cluster_result_types;
  new_cluster_result_types.reserve(new_cluster_results.size());
  for (const auto& new_cluster_result : new_cluster_results)
    new_cluster_result_types.push_back(new_cluster_result.getType());

  auto new_cluster = builder->create<tf_device::ClusterOp>(
      cluster.getLoc(), new_cluster_result_types,
      /*operands=*/llvm::ArrayRef<Value>{}, cluster.getAttrs());
  new_cluster.body().takeBody(cluster.body());

  // Remap uses outside the new cluster from values defined in its body to the
  // new cluster results.
  auto operand_not_in_cluster = [&](OpOperand& operand) {
    return !new_cluster.getOperation()->isProperAncestor(operand.getOwner());
  };
  for (auto result :
       llvm::zip(new_cluster.GetBody().getTerminator()->getOperands(),
                 new_cluster.getResults()))
    std::get<0>(result).replaceUsesWithIf(std::get<1>(result),
                                          operand_not_in_cluster);

  cluster.erase();
  return new_cluster;
}
344 
// Extracts and move outside compiled ops that do not create dependencies in the
// cluster to after the cluster. `host_device` may already be filled by head
// extraction and is computed here only if empty. On extraction, `cluster` is
// replaced with the rebuilt cluster and `cluster_updated` is set to true.
mlir::LogicalResult LiftTailOutsideCompiledOps(
    OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis,
    const mlir::TF::RuntimeDevices& devices, std::string host_device,
    tf_device::ClusterOp* cluster, bool* cluster_updated) {
  llvm::SmallVector<Operation*, 4> tail_outside_compiled_ops;
  llvm::SmallVector<Value, 4> cluster_results;
  FindOutsideCompiledOpsAtTailAndClusterResults(side_effect_analysis, *cluster,
                                                &tail_outside_compiled_ops,
                                                &cluster_results);
  // Nothing to extract is not an error.
  if (tail_outside_compiled_ops.empty()) return success();

  if (host_device.empty())
    if (failed(tensorflow::GetHostDeviceOutsideComputation(devices, *cluster,
                                                           &host_device)))
      return failure();

  // Forward all results of cluster first. These results will be remapped once
  // a new cluster is formed.
  cluster->replaceAllUsesWith(
      cluster->GetBody().getTerminator()->getOperands());

  CreateTailComputation(builder, *cluster, tail_outside_compiled_ops,
                        host_device);

  *cluster = UpdateClusterResults(builder, *cluster, cluster_results);

  *cluster_updated = true;
  return success();
}
376 
// Removes aliased outputs in cluster from ops outside of cluster. A cluster
// result is "aliased" when the value it returns is not defined by an op nested
// in the cluster (e.g. a block argument or a value defined above); such
// results are replaced by the forwarded value and dropped from the cluster.
void RemoveClusterAliasedOutputs(OpBuilder* builder,
                                 tf_device::ClusterOp cluster) {
  llvm::SmallVector<Value, 4> used_old_cluster_results;
  llvm::SmallVector<Value, 4> new_cluster_results;
  llvm::SmallVector<Type, 4> new_cluster_result_types;
  Operation* cluster_terminator = cluster.GetBody().getTerminator();
  for (auto result :
       llvm::zip(cluster_terminator->getOperands(), cluster.getResults())) {
    Value cluster_terminator_operand = std::get<0>(result);
    // Keep results whose value is produced by an op inside the cluster.
    if (cluster_terminator_operand.getDefiningOp() &&
        cluster.getOperation()->isProperAncestor(
            cluster_terminator_operand.getDefiningOp())) {
      new_cluster_results.push_back(cluster_terminator_operand);
      new_cluster_result_types.push_back(cluster_terminator_operand.getType());
      used_old_cluster_results.push_back(std::get<1>(result));
    } else {
      // Aliased result: forward the underlying value directly to users.
      std::get<1>(result).replaceAllUsesWith(cluster_terminator_operand);
    }
  }

  // Nothing was aliased; keep the cluster as is.
  if (new_cluster_results.size() == cluster.getNumResults()) return;

  // Rebuild the cluster with the pruned result list, reusing the old body.
  builder->setInsertionPoint(cluster);
  auto new_cluster = builder->create<tf_device::ClusterOp>(
      cluster.getLoc(), new_cluster_result_types,
      /*operands=*/llvm::ArrayRef<Value>{}, cluster.getAttrs());
  new_cluster.body().takeBody(cluster.body());
  new_cluster.GetBody().getTerminator()->setOperands(new_cluster_results);

  for (auto result :
       llvm::zip(used_old_cluster_results, new_cluster.getResults()))
    std::get<0>(result).replaceAllUsesWith(std::get<1>(result));

  cluster.erase();
}
413 
// Module pass that extracts outside compiled (CPU) computation from the head
// and tail of each `tf_device.cluster` into host-side `tf_device.launch` ops.
struct TPUExtractHeadTailOutsideCompilation
    : public PassWrapper<TPUExtractHeadTailOutsideCompilation,
                         OperationPass<ModuleOp>> {
  void runOnOperation() override;
};
419 
void TPUExtractHeadTailOutsideCompilation::runOnOperation() {
  auto& side_effect_analysis = getAnalysis<TF::SideEffectAnalysis>();
  // Get runtime devices information from the closest parent module.
  auto module = getOperation();
  mlir::TF::RuntimeDevices devices;
  if (failed(tensorflow::GetDevicesFromOp(module, &devices)))
    return signalPassFailure();

  OpBuilder builder(&getContext());
  // Collect clusters up front since lifting may erase and recreate them,
  // which would interfere with a live walk.
  llvm::SmallVector<tf_device::ClusterOp, 4> clusters;
  module.walk(
      [&](tf_device::ClusterOp cluster) { clusters.push_back(cluster); });

  for (tf_device::ClusterOp cluster : clusters) {
    std::string host_device;
    bool cluster_updated = false;
    // Head extraction may fill `host_device`, which tail extraction reuses.
    // Tail extraction may replace `cluster` with a rebuilt op.
    if (failed(LiftHeadOutsideCompiledOps(&builder, side_effect_analysis,
                                          devices, cluster, &host_device,
                                          &cluster_updated)) ||
        failed(LiftTailOutsideCompiledOps(&builder, side_effect_analysis,
                                          devices, host_device, &cluster,
                                          &cluster_updated)))
      return signalPassFailure();
    // Drop cluster results that merely forward values defined outside.
    if (cluster_updated) RemoveClusterAliasedOutputs(&builder, cluster);
  }
}
446 
447 }  // anonymous namespace
448 
// Creates a new instance of the pass; the public entry point declared in
// passes.h.
std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUExtractHeadTailOutsideCompilationPass() {
  return std::make_unique<TPUExtractHeadTailOutsideCompilation>();
}
453 
// Registers the pass under the
// `-tf-tpu-extract-head-tail-outside-compilation` command-line flag.
static PassRegistration<TPUExtractHeadTailOutsideCompilation> pass(
    "tf-tpu-extract-head-tail-outside-compilation",
    "Extracts TPU head or tail outside compilation to separate "
    "parallel_execute.");
458 
459 }  // namespace TFTPU
460 }  // namespace mlir
461