• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <tuple>
18 #include <type_traits>
19 #include <utility>
20 
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SetVector.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/ADT/StringRef.h"
26 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
27 #include "mlir/IR/Attributes.h"  // from @llvm-project
28 #include "mlir/IR/Block.h"  // from @llvm-project
29 #include "mlir/IR/Builders.h"  // from @llvm-project
30 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
31 #include "mlir/IR/Operation.h"  // from @llvm-project
32 #include "mlir/IR/Value.h"  // from @llvm-project
33 #include "mlir/IR/Visitors.h"  // from @llvm-project
34 #include "mlir/Pass/Pass.h"  // from @llvm-project
35 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
36 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
37 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
38 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
40 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
41 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
42 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
43 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
44 #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
45 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
46 
47 namespace mlir {
48 namespace TFTPU {
49 
50 // This pass extracts a CPU computation cluster with `_xla_outside_compilation`
51 // annotation from the head or tail of a TPU cluster.
52 
53 namespace {
54 
55 constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
56 
HasOutsideCompilationAttribute(Operation * op)57 bool HasOutsideCompilationAttribute(Operation* op) {
58   return op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr) != nullptr;
59 }
60 
61 // Finds op that created a given value. If the value is a BlockArgument, this
62 // returns the owner of the Block.
GetOpOfValue(Value value)63 Operation* GetOpOfValue(Value value) {
64   if (auto block_arg = value.dyn_cast<BlockArgument>())
65     return block_arg.getOwner()->getParentOp();
66 
67   return value.getDefiningOp();
68 }
69 
70 // Checks if `op` is nested in `block`.
OpInBlock(Operation * op,Block * block)71 bool OpInBlock(Operation* op, Block* block) {
72   Block* op_block = op->getBlock();
73   while (op_block) {
74     if (op_block == block) return true;
75     if (auto* parent_op = op_block->getParentOp()) {
76       op_block = parent_op->getBlock();
77     } else {
78       break;
79     }
80   }
81   return false;
82 }
83 
84 // Wraps block in a Launch. External uses of ops in the block will be return
85 // values of the Launch and remapped to the Launch results. If `before` is set
86 // to true, the Launch is created before `op`. Otherwise the Launch is created
87 // after `op`.
CreateLaunchForBlock(OpBuilder * builder,Operation * op,bool before,Block * launch_block,llvm::StringRef host_device)88 tf_device::LaunchOp CreateLaunchForBlock(OpBuilder* builder, Operation* op,
89                                          bool before, Block* launch_block,
90                                          llvm::StringRef host_device) {
91   // Find results and result types of ops in block that needs to returned.
92   llvm::SmallVector<Value, 4> launch_results;
93   llvm::SmallVector<Type, 4> launch_result_types;
94   for (Operation& head_outside_compiled_op : *launch_block) {
95     for (Value result : head_outside_compiled_op.getResults()) {
96       bool has_external_uses = false;
97       for (Operation* user : result.getUsers()) {
98         if (OpInBlock(user, launch_block)) continue;
99         has_external_uses = true;
100         break;
101       }
102       if (has_external_uses) {
103         launch_results.push_back(result);
104         launch_result_types.push_back(result.getType());
105       }
106     }
107   }
108 
109   before ? builder->setInsertionPoint(op) : builder->setInsertionPointAfter(op);
110   auto launch = builder->create<tf_device::LaunchOp>(
111       op->getLoc(), builder->getStringAttr(host_device), launch_result_types);
112   launch.body().push_back(launch_block);
113 
114   builder->setInsertionPointToEnd(&launch.GetBody());
115   builder->create<tf_device::ReturnOp>(op->getLoc(), launch_results);
116 
117   return launch;
118 }
119 
120 // Checks if an operation is a supported TPU embedding op.
IsEmbeddingOp(Operation * op)121 bool IsEmbeddingOp(Operation* op) {
122   return isa<TF::EnqueueTPUEmbeddingRaggedTensorBatchOp,
123              TF::EnqueueTPUEmbeddingSparseTensorBatchOp,
124              TF::EnqueueTPUEmbeddingArbitraryTensorBatchOp,
125              TF::RecvTPUEmbeddingActivationsOp,
126              TF::SendTPUEmbeddingGradientsOp>(op);
127 }
128 
// Returns a set of ops that are outside compiled and can be extracted to before
// the TPU computation. These ops are either connected to the inputs of the TPU
// computation or other ops that can be extracted, and have no operands from
// other ops in the TPU computation that cannot be extracted.
llvm::SmallVector<Operation*, 4> FindOutsideCompiledOpsAtHead(
    const TF::SideEffectAnalysis& side_effect_analysis,
    tf_device::ClusterOp cluster) {
  const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
      cluster->getParentOfType<func::FuncOp>());
  Region* cluster_region = &cluster.body();
  // SetVector keeps insertion (program) order while allowing O(1) membership
  // checks in the operand walk below.
  llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops;

  auto cluster_ops = cluster.GetBody().without_terminator();
  for (Operation& cluster_op : cluster_ops) {
    if (!HasOutsideCompilationAttribute(&cluster_op)) continue;
    // An outside compiled op can be extracted if its operands are not from
    // other ops in the cluster that cannot be extracted.

    // Check if the side effecting op right before this side effecting op, if
    // it is side effecting, can be head extracted. Because of op ordering due
    // to side effects, if this is not true, this op cannot be head extracted.
    // TODO(lyandy): Remove special handling of embedding ops. Currently the IR
    // is in a topological sort order and depending on that ordering, embedding
    // ops may prevent other ops from being head extracted.
    auto predecessors = analysis.DirectControlPredecessors(&cluster_op);
    if (!predecessors.empty() && !IsEmbeddingOp(&cluster_op)) {
      bool skip = false;
      // Only the nearest non-embedding control predecessor matters; if it was
      // not itself marked head-extractable, extracting this op would reorder
      // side effects, so skip it.
      for (Operation* predecessor : llvm::reverse(predecessors)) {
        if (IsEmbeddingOp(predecessor)) continue;
        skip = !head_outside_compiled_ops.contains(predecessor);
        break;
      }
      if (skip) continue;
    }

    // Walk all nested ops: every operand must come from outside the cluster
    // region, from an already-extractable op, or from `cluster_op` itself.
    auto walk_result = cluster_op.walk([&](Operation* op) {
      for (Value operand : op->getOperands()) {
        Operation* operand_op = GetOpOfValue(operand);
        if (head_outside_compiled_ops.count(operand_op) ||
            operand_op == &cluster_op)
          continue;

        // Operand is produced by a non-extractable op in the cluster; this op
        // cannot be head extracted.
        if (operand_op->getParentRegion() == cluster_region)
          return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });

    if (!walk_result.wasInterrupted())
      head_outside_compiled_ops.insert(&cluster_op);
  }

  return head_outside_compiled_ops.takeVector();
}
183 
184 // Moves head outside compiled ops into its own `tf_device.LaunchOp`
185 // computation before the cluster.
CreateHeadComputation(OpBuilder * builder,tf_device::ClusterOp cluster,llvm::ArrayRef<Operation * > head_outside_compiled_ops,llvm::StringRef host_device)186 void CreateHeadComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
187                            llvm::ArrayRef<Operation*> head_outside_compiled_ops,
188                            llvm::StringRef host_device) {
189   Block* launch_block = new Block;
190   for (Operation* head_outside_compiled_op : head_outside_compiled_ops) {
191     head_outside_compiled_op->removeAttr(kXlaOutsideCompilationAttr);
192     head_outside_compiled_op->moveBefore(launch_block, launch_block->end());
193   }
194 
195   tf_device::LaunchOp launch = CreateLaunchForBlock(
196       builder, cluster, /*before=*/true, launch_block, host_device);
197 
198   for (auto result : llvm::zip(launch.GetBody().getTerminator()->getOperands(),
199                                launch.getResults()))
200     replaceAllUsesInRegionWith(std::get<0>(result), std::get<1>(result),
201                                cluster.body());
202 }
203 
204 // Extracts and move outside compiled ops that have no dependencies in the
205 // cluster to before the cluster.
LiftHeadOutsideCompiledOps(OpBuilder * builder,const TF::SideEffectAnalysis & side_effect_analysis,const mlir::TF::RuntimeDevices & devices,tf_device::ClusterOp cluster,std::string * host_device,bool * cluster_updated)206 mlir::LogicalResult LiftHeadOutsideCompiledOps(
207     OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis,
208     const mlir::TF::RuntimeDevices& devices, tf_device::ClusterOp cluster,
209     std::string* host_device, bool* cluster_updated) {
210   llvm::SmallVector<Operation*, 4> head_outside_compiled_ops =
211       FindOutsideCompiledOpsAtHead(side_effect_analysis, cluster);
212   if (head_outside_compiled_ops.empty()) return success();
213   if (failed(tensorflow::GetHostDeviceOutsideComputation(devices, cluster,
214                                                          host_device)))
215     return failure();
216 
217   CreateHeadComputation(builder, cluster, head_outside_compiled_ops,
218                         *host_device);
219 
220   *cluster_updated = true;
221   return success();
222 }
223 
224 // Fills `tail_outside_compiled_ops` with ops that are outside compiled and
225 // can be extracted to after the TPU computation, and `cluster_results` with new
226 // results of the cluster. These ops are either connected to the output of the
227 // TPU computation or other ops that can be extracted, and have no results used
228 // by other ops in the TPU computation that cannot be extracted.
FindOutsideCompiledOpsAtTailAndClusterResults(const TF::SideEffectAnalysis & side_effect_analysis,tf_device::ClusterOp cluster,llvm::SmallVectorImpl<Operation * > * tail_outside_compiled_ops,llvm::SmallVectorImpl<Value> * cluster_results)229 void FindOutsideCompiledOpsAtTailAndClusterResults(
230     const TF::SideEffectAnalysis& side_effect_analysis,
231     tf_device::ClusterOp cluster,
232     llvm::SmallVectorImpl<Operation*>* tail_outside_compiled_ops,
233     llvm::SmallVectorImpl<Value>* cluster_results) {
234   const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
235       cluster->getParentOfType<func::FuncOp>());
236   Region* cluster_region = &cluster.body();
237   llvm::SmallSetVector<Operation*, 4> tail_outside_compiled_ops_set;
238   Operation* terminator = cluster.GetBody().getTerminator();
239   llvm::SmallSetVector<Value, 4> cluster_results_set;
240   cluster_results_set.insert(terminator->getOperands().begin(),
241                              terminator->getOperands().end());
242 
243   auto cluster_ops = llvm::reverse(cluster.GetBody().without_terminator());
244   for (Operation& cluster_op : cluster_ops) {
245     if (!HasOutsideCompilationAttribute(&cluster_op)) continue;
246 
247     // Check if the side effecting op right after this side effecting op, if
248     // it is side effecting, can be tail extracted. Because of op ordering due
249     // to side effects, if this is not true, this op cannot be tail extracted.
250     // TODO(lyandy): Remove special handling of embedding ops. Currently the IR
251     // is in a topological sort order and depending on that ordering, embedding
252     // ops may prevent other ops from being tail extracted.
253     auto successors = analysis.DirectControlSuccessors(
254         &cluster_op, [&terminator](Operation* op) { return op != terminator; });
255     if (!successors.empty() && !IsEmbeddingOp(&cluster_op)) {
256       bool skip = false;
257       for (Operation* successor : successors) {
258         if (IsEmbeddingOp(successor)) continue;
259         skip = !tail_outside_compiled_ops_set.contains(successor);
260         break;
261       }
262       if (skip) continue;
263     }
264 
265     llvm::SmallVector<int, 4> results_to_forward;
266     bool can_be_extracted =
267         llvm::all_of(cluster_op.getUsers(), [&](Operation* op) {
268           return op == terminator || tail_outside_compiled_ops_set.count(op);
269         });
270     if (!can_be_extracted) continue;
271 
272     // Collect operands of cluster op that are generated within the cluster.
273     // These values should be returned by the cluster.
274     cluster_op.walk([&](Operation* op) {
275       for (Value operand : op->getOperands()) {
276         Operation* operand_op = GetOpOfValue(operand);
277         if (operand_op->getParentRegion() == cluster_region)
278           cluster_results_set.insert(operand);
279       }
280     });
281 
282     // Remove results of op to be extracted as there are no uses in the cluster.
283     for (Value result : cluster_op.getResults())
284       cluster_results_set.remove(result);
285     // Insert all ops including nested ops for checking outputs/side effects.
286     cluster_op.walk(
287         [&](Operation* op) { tail_outside_compiled_ops_set.insert(op); });
288 
289     // Only add top level ops to output vector.
290     tail_outside_compiled_ops->push_back(&cluster_op);
291   }
292 
293   *cluster_results = cluster_results_set.takeVector();
294 }
295 
296 // Moves tail outside compiled ops into its own `tf_device.LaunchOp`
297 // computation after the cluster.
CreateTailComputation(OpBuilder * builder,tf_device::ClusterOp cluster,llvm::ArrayRef<Operation * > tail_outside_compiled_ops,llvm::StringRef host_device)298 void CreateTailComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
299                            llvm::ArrayRef<Operation*> tail_outside_compiled_ops,
300                            llvm::StringRef host_device) {
301   Block* launch_block = new Block;
302   for (Operation* tail_outside_compiled_op : tail_outside_compiled_ops) {
303     tail_outside_compiled_op->removeAttr(kXlaOutsideCompilationAttr);
304     tail_outside_compiled_op->moveBefore(launch_block, launch_block->begin());
305   }
306 
307   tf_device::LaunchOp launch = CreateLaunchForBlock(
308       builder, cluster, /*before=*/false, launch_block, host_device);
309 
310   auto operand_not_in_launch = [&](OpOperand& operand) {
311     return !launch.getOperation()->isProperAncestor(operand.getOwner());
312   };
313   for (auto result : llvm::zip(launch.GetBody().getTerminator()->getOperands(),
314                                launch.getResults()))
315     std::get<0>(result).replaceUsesWithIf(std::get<1>(result),
316                                           operand_not_in_launch);
317 }
318 
319 // Updates cluster with updated cluster results after extracting tail outside
320 // compiled ops.
UpdateClusterResults(OpBuilder * builder,tf_device::ClusterOp cluster,llvm::ArrayRef<Value> new_cluster_results)321 tf_device::ClusterOp UpdateClusterResults(
322     OpBuilder* builder, tf_device::ClusterOp cluster,
323     llvm::ArrayRef<Value> new_cluster_results) {
324   Operation* old_terminator = cluster.GetBody().getTerminator();
325   builder->setInsertionPoint(old_terminator);
326   builder->create<tf_device::ReturnOp>(old_terminator->getLoc(),
327                                        new_cluster_results);
328   old_terminator->erase();
329 
330   builder->setInsertionPoint(cluster);
331   llvm::SmallVector<Type, 4> new_cluster_result_types;
332   new_cluster_result_types.reserve(new_cluster_results.size());
333   for (const auto& new_cluster_result : new_cluster_results)
334     new_cluster_result_types.push_back(new_cluster_result.getType());
335 
336   auto new_cluster = builder->create<tf_device::ClusterOp>(
337       cluster.getLoc(), new_cluster_result_types,
338       /*operands=*/llvm::ArrayRef<Value>{}, cluster->getAttrs());
339   new_cluster.body().takeBody(cluster.body());
340 
341   auto operand_not_in_cluster = [&](OpOperand& operand) {
342     return !new_cluster.getOperation()->isProperAncestor(operand.getOwner());
343   };
344   for (auto result :
345        llvm::zip(new_cluster.GetBody().getTerminator()->getOperands(),
346                  new_cluster.getResults()))
347     std::get<0>(result).replaceUsesWithIf(std::get<1>(result),
348                                           operand_not_in_cluster);
349 
350   cluster.erase();
351   return new_cluster;
352 }
353 
354 // Extracts and move outside compiled ops that do not create dependencies in the
355 // cluster to after the cluster.
LiftTailOutsideCompiledOps(OpBuilder * builder,const TF::SideEffectAnalysis & side_effect_analysis,const mlir::TF::RuntimeDevices & devices,std::string host_device,tf_device::ClusterOp * cluster,bool * cluster_updated)356 mlir::LogicalResult LiftTailOutsideCompiledOps(
357     OpBuilder* builder, const TF::SideEffectAnalysis& side_effect_analysis,
358     const mlir::TF::RuntimeDevices& devices, std::string host_device,
359     tf_device::ClusterOp* cluster, bool* cluster_updated) {
360   llvm::SmallVector<Operation*, 4> tail_outside_compiled_ops;
361   llvm::SmallVector<Value, 4> cluster_results;
362   FindOutsideCompiledOpsAtTailAndClusterResults(side_effect_analysis, *cluster,
363                                                 &tail_outside_compiled_ops,
364                                                 &cluster_results);
365   if (tail_outside_compiled_ops.empty()) return success();
366 
367   if (host_device.empty())
368     if (failed(tensorflow::GetHostDeviceOutsideComputation(devices, *cluster,
369                                                            &host_device)))
370       return failure();
371 
372   // Forward all results of cluster first. These results will be remapped once
373   // a new cluster is formed.
374   cluster->replaceAllUsesWith(
375       cluster->GetBody().getTerminator()->getOperands());
376 
377   CreateTailComputation(builder, *cluster, tail_outside_compiled_ops,
378                         host_device);
379 
380   *cluster = UpdateClusterResults(builder, *cluster, cluster_results);
381 
382   *cluster_updated = true;
383   return success();
384 }
385 
386 // Removes aliased outputs in cluster from ops outside of cluster.
RemoveClusterAliasedOutputs(OpBuilder * builder,tf_device::ClusterOp cluster)387 void RemoveClusterAliasedOutputs(OpBuilder* builder,
388                                  tf_device::ClusterOp cluster) {
389   llvm::SmallVector<Value, 4> used_old_cluster_results;
390   llvm::SmallVector<Value, 4> new_cluster_results;
391   llvm::SmallVector<Type, 4> new_cluster_result_types;
392   Operation* cluster_terminator = cluster.GetBody().getTerminator();
393   for (auto result :
394        llvm::zip(cluster_terminator->getOperands(), cluster.getResults())) {
395     Value cluster_terminator_operand = std::get<0>(result);
396     if (cluster_terminator_operand.getDefiningOp() &&
397         cluster.getOperation()->isProperAncestor(
398             cluster_terminator_operand.getDefiningOp())) {
399       new_cluster_results.push_back(cluster_terminator_operand);
400       new_cluster_result_types.push_back(cluster_terminator_operand.getType());
401       used_old_cluster_results.push_back(std::get<1>(result));
402     } else {
403       std::get<1>(result).replaceAllUsesWith(cluster_terminator_operand);
404     }
405   }
406 
407   if (new_cluster_results.size() == cluster.getNumResults()) return;
408 
409   builder->setInsertionPoint(cluster);
410   auto new_cluster = builder->create<tf_device::ClusterOp>(
411       cluster.getLoc(), new_cluster_result_types,
412       /*operands=*/llvm::ArrayRef<Value>{}, cluster->getAttrs());
413   new_cluster.body().takeBody(cluster.body());
414   new_cluster.GetBody().getTerminator()->setOperands(new_cluster_results);
415 
416   for (auto result :
417        llvm::zip(used_old_cluster_results, new_cluster.getResults()))
418     std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
419 
420   cluster.erase();
421 }
422 
423 // Checks if `type` is allowed for data on TPUs. String and resources cannot be
424 // assigned to TPUs. There are other TF types that are not allowed on TPUs, but
425 // these will be removed by successive passes in TF/XLA bridge phase 2.
TypeValidForTPU(Type type)426 bool TypeValidForTPU(Type type) {
427   Type elem = getElementTypeOrSelf(type);
428   return !elem.isa<TF::ResourceType>() && !elem.isa<TF::StringType>();
429 }
430 
431 // Check that cluster results are valid. An result is invalid when it does not
432 // have a valid XLA type.
CheckClusterResults(tf_device::ClusterOp cluster)433 LogicalResult CheckClusterResults(tf_device::ClusterOp cluster) {
434   for (OpResult result : cluster.getResults()) {
435     if (!TypeValidForTPU(result.getType())) {
436       cluster.emitError()
437           << "The TPUExtractHeadTailOutsideCompilation pass produced a TPU "
438              "cluster with a result with a non-XLA type: "
439           << result.getType();
440       return failure();
441     }
442   }
443   return success();
444 }
445 
446 // Check the validity of the module, post-pass.
CheckPostconditions(ModuleOp module)447 LogicalResult CheckPostconditions(ModuleOp module) {
448   auto walk_result = module.walk([&](tf_device::ClusterOp cluster) {
449     if (failed(CheckClusterResults(cluster))) return WalkResult::interrupt();
450     return WalkResult::advance();
451   });
452   if (walk_result.wasInterrupted()) return failure();
453   return success();
454 }
455 
// Module pass that extracts `_xla_outside_compilation` ops at the head/tail
// of TPU clusters into host-side `tf_device.LaunchOp`s (see runOnOperation).
struct TPUExtractHeadTailOutsideCompilationPass
    : public TF::TPUExtractHeadTailOutsideCompilationPassBase<
          TPUExtractHeadTailOutsideCompilationPass> {
  void runOnOperation() override;
};
461 
runOnOperation()462 void TPUExtractHeadTailOutsideCompilationPass::runOnOperation() {
463   auto& side_effect_analysis = getAnalysis<TF::SideEffectAnalysis>();
464   // Get runtime devices information from the closest parent module.
465   auto module = getOperation();
466   mlir::TF::RuntimeDevices devices;
467   if (failed(tensorflow::GetDevicesFromOp(module, &devices)))
468     return signalPassFailure();
469 
470   OpBuilder builder(&getContext());
471   llvm::SmallVector<tf_device::ClusterOp, 4> clusters;
472   module.walk(
473       [&](tf_device::ClusterOp cluster) { clusters.push_back(cluster); });
474 
475   for (tf_device::ClusterOp cluster : clusters) {
476     std::string host_device;
477     bool cluster_updated = false;
478     if (failed(LiftHeadOutsideCompiledOps(&builder, side_effect_analysis,
479                                           devices, cluster, &host_device,
480                                           &cluster_updated)) ||
481         failed(LiftTailOutsideCompiledOps(&builder, side_effect_analysis,
482                                           devices, host_device, &cluster,
483                                           &cluster_updated)))
484       return signalPassFailure();
485     if (cluster_updated) RemoveClusterAliasedOutputs(&builder, cluster);
486   }
487 
488   if (failed(CheckPostconditions(module))) return signalPassFailure();
489 }
490 
491 }  // anonymous namespace
492 
493 std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUExtractHeadTailOutsideCompilationPass()494 CreateTPUExtractHeadTailOutsideCompilationPass() {
495   return std::make_unique<TPUExtractHeadTailOutsideCompilationPass>();
496 }
497 
498 }  // namespace TFTPU
499 }  // namespace mlir
500