/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <string>
#include <type_traits>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Builders.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/IR/Types.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Pass/PassRegistry.h"  // TF:llvm-project
#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace mlir {
namespace TFTPU {

// NOLINTNEXTLINE
static llvm::cl::opt<bool> tpu_compile_metadata_debug(
    "tpu_compile_metadata_debug",
    llvm::cl::desc("Serialize TPUCompileMetadataProto metadata in "
                   "'tf._TPUCompileMlir' op as a proto debug string"));

constexpr char kNumReplicasAttr[] = "num_replicas";
constexpr char kNumCoresPerReplicaAttr[] = "num_cores_per_replica";
constexpr char kStepMarkerLocationAttr[] = "step_marker_location";
constexpr char kPaddingMapAttr[] = "padding_map";
constexpr char kDeviceAttr[] = "device";
constexpr char kDevicesAttr[] = "devices";
constexpr char kVersionsAttr[] = "tf.versions";

// Rewrites `tf_device.launch_func` operations assigned to TPU into actual TPU
// jit-compile runtime ops.
//
// For example:
//   %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster", func =
//         @tpu_func}
//   %2 = "tf.SomeOp"(%1)
//
// would become the following ops (unimportant attributes and types omitted):
//    %1 = "tf.Shape"(%0)
//    %2:2 = "tf._TPUCompileMlir"(%1) {mlir_module = "<Serialized @tpu_func>"}
//    "tf.TPUCompileSucceededAssert"(%2#0)
//    %3 = "tf.TPUExecute"(%0, %2#1)
//    %4 = "tf.SomeOp"(%3)

namespace {
struct TPURewritePass : public ModulePass<TPURewritePass> {
  void runOnModule() override;
};

// Creates a missing attribute error message.
std::string CreateMissingAttributeMsg(llvm::StringRef attribute) {
  return llvm::formatv("requires attribute '{0}'", attribute).str();
}

LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func,
                                          std::string* serialized_func_module) {
  ModuleOp module = entry_func.getParentOfType<ModuleOp>();
  SymbolTable entry_module_table(module);
  llvm::SmallVector<FuncOp, 4> referenced({entry_func});

  // Create a new module to hold func and all referenced functions.
  OwningModuleRef module_for_func =
      ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext()));
  auto parent_module = entry_func.getParentOfType<ModuleOp>();
  auto versions_attr = parent_module.getAttr(kVersionsAttr);
  if (!versions_attr)
    return parent_module.emitError(CreateMissingAttributeMsg(kVersionsAttr));

  module_for_func.get().getOperation()->setAttr(kVersionsAttr, versions_attr);
  SymbolTable symbol_table(module_for_func.get());

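  // Walk the transitive set of functions referenced from entry_func, cloning
  // each one into the new module so the serialized module is self-contained.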
  while (!referenced.empty()) {
    auto func = referenced.pop_back_val();

    // Skip functions that have already been cloned into the new module.
    if (symbol_table.lookup<FuncOp>(func.getName())) continue;

    // Find any SymbolRefAttr in func that maps to a FuncOp. All such FuncOps
    // must be cloned into module_for_func to make sure the new module is
    // self-contained.
    Optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(func);
    assert(uses && "expected to be able to collect symbol uses");
    for (SymbolTable::SymbolUse use : *uses) {
      FuncOp referenced_func = entry_module_table.lookup<FuncOp>(
          use.getSymbolRef().cast<FlatSymbolRefAttr>().getValue());

      // Skip symbols that do not map to a function.
      if (!referenced_func) continue;

      referenced.emplace_back(referenced_func);
    }

    auto clone = func.clone();
    if (clone.getName() == entry_func.getName()) {
      // We can simply change the name of the TPU program's main function
      // because there should be no other reference to it.
      clone.setName("main");
    }
    symbol_table.insert(clone);
  }

  // Serialize module and return.
  {
    llvm::raw_string_ostream os(*serialized_func_module);
    module_for_func.get().print(os);
  }
  return success();
}

// Populates a TPUCompileMetadataProto from attributes of a
// `tf_device::LaunchFuncOp`. If any necessary attributes are missing from the
// op, a failure will be returned.
// TODO(lyandy): Support session handle and guaranteed consts.
LogicalResult SetMetadataProtoFromLaunchFuncOp(
    tf_device::LaunchFuncOp op, int num_replicas, int num_cores_per_replica,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  metadata->set_num_replicas(num_replicas);
  metadata->set_num_cores_per_replica(num_cores_per_replica);

  auto step_marker_location =
      op.getAttrOfType<StringAttr>(kStepMarkerLocationAttr);
  if (!step_marker_location)
    return op.emitOpError(CreateMissingAttributeMsg(kStepMarkerLocationAttr));

  // Default to `STEP_MARK_AT_ENTRY` for step marker location if attribute is
  // empty.
  xla::DebugOptions::StepMarkerLocation location =
      xla::DebugOptions::STEP_MARK_AT_ENTRY;
  if (!step_marker_location.getValue().empty() &&
      !xla::DebugOptions::StepMarkerLocation_Parse(
          std::string(step_marker_location.getValue()), &location))
    return op.emitOpError(llvm::formatv("bad '{0}' attribute with value '{1}'",
                                        kStepMarkerLocationAttr,
                                        step_marker_location.getValue()));

  metadata->set_step_marker_location(location);

  auto padding_map = op.getAttrOfType<ArrayAttr>(kPaddingMapAttr);
  if (!padding_map)
    return op.emitOpError(CreateMissingAttributeMsg(kPaddingMapAttr));

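  // Each entry of the padding map is a serialized tensorflow::tpu::PaddingMap
  // proto held in a string attribute; parse every entry into the metadata.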
  for (const auto padding_and_idx : llvm::enumerate(padding_map)) {
    auto& padding_attr = padding_and_idx.value();
    auto padding_attr_str = padding_attr.dyn_cast<StringAttr>();
    if (!padding_attr_str)
      return op.emitOpError(
          llvm::formatv("bad '{0}' attribute at index {1}, not a string",
                        kPaddingMapAttr, padding_and_idx.index()));

    tensorflow::tpu::PaddingMap* padding =
        metadata->mutable_padding_maps()->Add();
    if (!padding->ParseFromString(std::string(padding_attr_str.getValue())))
      return op.emitOpError(llvm::formatv(
          "bad '{0}' attribute at index {1} with value '{2}'", kPaddingMapAttr,
          padding_and_idx.index(), padding_attr_str.getValue()));
  }

  // Set args metadata in proto.
  for (auto operand_type_and_idx : llvm::enumerate(op.getOperandTypes())) {
    Type operand_type = operand_type_and_idx.value();
    tensorflow::tpu::TPUCompileMetadataProto::Arg* arg = metadata->add_args();
    tensorflow::DataType dtype;
    tensorflow::Status status =
        tensorflow::ConvertToDataType(operand_type, &dtype);
    if (!status.ok())
      return op.emitOpError(
          llvm::formatv("failed to determine operand type at index {0}: {1}",
                        operand_type_and_idx.index(), status.error_message()));

    arg->set_dtype(dtype);
    // TODO(lyandy): Support other arg kinds.
    if (dtype == tensorflow::DT_RESOURCE)
      arg->set_kind(tensorflow::tpu::TPUCompileMetadataProto::Arg::VARIABLE);
    else
      arg->set_kind(tensorflow::tpu::TPUCompileMetadataProto::Arg::PARAMETER);

    // Populate argument shapes.
    *arg->mutable_shape() = tensorflow::TensorShapeProto();
    if (auto ranked_tensor_type = operand_type.dyn_cast<RankedTensorType>()) {
      tensorflow::TensorShapeProto shape_proto;
      ConvertToTensorShapeProto(ranked_tensor_type.getShape(), &shape_proto);
      *arg->mutable_shape() = std::move(shape_proto);
    } else {
      arg->mutable_shape()->set_unknown_rank(true);
    }

    // TODO(lyandy): Determine proper sharding of args once topology and devices
    // are propagated to the pass.
    xla::OpSharding sharding;
    sharding.set_type(xla::OpSharding::MAXIMAL);
    sharding.add_tile_assignment_dimensions(1);
    sharding.add_tile_assignment_devices(0);
    *arg->mutable_sharding() = std::move(sharding);
  }

  // Set retvals metadata in proto.
  // TODO(lyandy): Determine proper sharding of retvals once topology and
  // devices are propagated to the pass.
  for (int i = 0; i < op.getNumResults(); ++i) {
    xla::OpSharding sharding;
    sharding.set_type(xla::OpSharding::MAXIMAL);
    sharding.add_tile_assignment_dimensions(1);
    sharding.add_tile_assignment_devices(0);
    *metadata->add_retvals()->mutable_sharding() = std::move(sharding);
  }

  return success();
}

// Creates a `tf._TPUCompileMlir` op that contains an MLIR module functionally
// equivalent to the function referenced by launch_func.
Operation* BuildCompileOp(tf_device::LaunchFuncOp launch_func, int num_replicas,
                          int num_cores_per_replica,
                          llvm::StringRef compilation_device,
                          OpBuilder* builder) {
  // TODO(b/139377366): Use tf_tpu.compile build method when it is defined.
  OperationState compile_op_state(launch_func.getLoc(), "tf._TPUCompileMlir");

  // Set metadata from attributes.
  tensorflow::tpu::TPUCompileMetadataProto metadata;
  if (failed(SetMetadataProtoFromLaunchFuncOp(
          launch_func, num_replicas, num_cores_per_replica, &metadata)))
    return nullptr;

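  // Attach the metadata to the compile op, either as a human-readable debug
  // string or as a serialized binary proto depending on the debug flag.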
  std::string txt_metadata;
  if (tpu_compile_metadata_debug)
    txt_metadata = metadata.DebugString();
  else
    metadata.SerializeToString(&txt_metadata);

  compile_op_state.addAttribute("metadata",
                                builder->getStringAttr(txt_metadata));

  // Build a shape op for each input to launch_func.
  // TODO(b/139377366): When shape inference is ready, we can use compile time
  // shape inference to get inputs that have static shapes and only use shape
  // ops for the rest.
  llvm::SmallVector<Value, 4> compile_op_operands;
  compile_op_operands.reserve(launch_func.getNumOperands());

  for (auto operand_and_idx : llvm::enumerate(launch_func.getOperands())) {
    // Skip adding shape op for operands that have static shapes.
    tensorflow::PartialTensorShape shape(
        metadata.args(operand_and_idx.index()).shape());
    if (shape.IsFullyDefined()) continue;

    auto shape_op = builder->create<TF::ShapeOp>(
        launch_func.getLoc(),
        RankedTensorType::get({-1}, builder->getIntegerType(64)),
        operand_and_idx.value());
    compile_op_operands.emplace_back(shape_op.getResult());
  }
  compile_op_state.addOperands(compile_op_operands);
  compile_op_state.addAttribute(
      "NumDynamicShapes",
      builder->getI64IntegerAttr(compile_op_operands.size()));

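  // Look up the function referenced by launch_func's `func` attribute; its
  // serialized form is attached below as the `mlir_module` attribute.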
  FlatSymbolRefAttr func_attr =
      launch_func.getAttrOfType<FlatSymbolRefAttr>("func");
  if (!func_attr) {
    launch_func.emitOpError("does not have `func` attribute");
    return nullptr;
  }
  FuncOp func = launch_func.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
      func_attr.getValue());

  std::string txt_module;
  if (failed(EncapsulateFuncAndSerialize(func, &txt_module))) return nullptr;
  compile_op_state.addAttribute("mlir_module",
                                builder->getStringAttr(txt_module));

  compile_op_state.addAttribute(kDeviceAttr,
                                builder->getStringAttr(compilation_device));

  // Result #0 is a string indicating whether compilation is successful or not.
  compile_op_state.addTypes(
      RankedTensorType::get({}, builder->getType<TF::StringType>()));

  // Result #1 is the key to look up the executable binary in the compilation
  // cache.
  compile_op_state.addTypes(
      RankedTensorType::get({}, builder->getType<TF::StringType>()));

  return builder->createOperation(compile_op_state);
}

// Creates a `tf.TPUExecute` op that executes the TPU program generated by
// `compile_op`.
Operation* BuildExecuteOp(Operation* compile_op,
                          tf_device::LaunchFuncOp launch_func,
                          OpBuilder* builder) {
  // TPUExecute inherits all launch_func inputs, and takes an additional input
  // for the compilation cache key.
  llvm::SmallVector<Value, 4> tensor_inputs(launch_func.getOperands());
  tensor_inputs.push_back(compile_op->getResult(1));

  // TODO(b/139377366): Need to snapshot all resource variable inputs in
  // follow-up CLs.

  // TPUExecute has the same output types as launch_func.
  return builder->create<TF::TPUExecuteOp>(
      launch_func.getLoc(), launch_func.getResultTypes(), tensor_inputs,
      llvm::ArrayRef<NamedAttribute>{});
}

// Creates a `tf.TPUCompileSucceededAssert` operation that parses compilation
// status of `compile_op` to check whether compilation is successful.
void BuildTPUCompileSucceededAssertOp(Operation* compile_op,
                                      OpBuilder* builder) {
  OperationState assert_op_state(compile_op->getLoc(),
                                 "tf.TPUCompileSucceededAssert");
  assert_op_state.addOperands(compile_op->getResult(0));
  builder->createOperation(assert_op_state);
}

// Rewrites a `tf_device.launch_func` operation into a set of TPU runtime ops
// that jit-compile and execute the function in `tf_device.launch_func` on TPU.
// Device assignment is determined from available devices in `devices`. If it
// is not possible to rewrite the operation or device assignment fails, a
// failure will be returned.
//
// For example, a non-replicated `tf_device.launch_func`:
//
// func @main(%arg0: tensor<i1>) {
//   %0 = "tf_device.launch_func"(%arg0)
//          {_tpu_replicate = "cluster0", device = "", func = @_func} :
//          (tensor<i1>) -> tensor<i1>
//   return
// }
//
// will be rewritten as:
//
// func @main(%arg0: tensor<i1>) {
//   %0 = "tf.Shape"(%arg0) : (tensor<i1>) -> tensor<?xi32>
//   %1:2 = "tf._TPUCompileMlir"(%0) {device = "/CPU:0"} :
//            (tensor<?xi32>) -> (tensor<!tf.string>, tensor<!tf.string>)
//   %2 = "tf.TPUExecute"(%arg0, %1#1) {device = "/TPU:0"} :
//            (tensor<i1>, tensor<!tf.string>) -> tensor<i1>
//   return
// }
//
// and a replicated `tf_device.launch_func`:
//
// func @main(%arg0: tensor<i1>, %arg1: tensor<i1>) {
//   %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i1>)
//                              {n = 2 : i32} {
//     %1 = "tf_device.launch_func"(%ri)
//            {_tpu_replicate = "cluster0", device = "", func = @_func} :
//            (tensor<i1>) -> tensor<i1>
//     tf_device.return %1 : tensor<i1>
//   }
//   return
// }
//
// will be rewritten as:
//
// func @main(%arg0: tensor<i1>, %arg1: tensor<i1>) {
//   %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i1>)
//                              {n = 2 : i32, devices = ["/TPU:0", "/TPU:1"]} {
//     %1 = "tf.Shape"(%ri) : (tensor<i1>) -> tensor<?xi32>
//     %2:2 = "tf._TPUCompileMlir"(%1) {device = "/CPU:0"} :
//              (tensor<?xi32>) -> (tensor<!tf.string>, tensor<!tf.string>)
//     %3 = "tf.TPUExecute"(%ri, %2#1) :
//            (tensor<i1>, tensor<!tf.string>) -> tensor<i1>
//     tf_device.return %3 : tensor<i1>
//   }
//   return
// }
LogicalResult Rewrite(
    tf_device::LaunchFuncOp launch_func,
    llvm::ArrayRef<tensorflow::DeviceNameUtils::ParsedName> devices,
    OpBuilder* builder) {
  // Skip non-TPU device launch_func.
  auto replicate_attr = launch_func.getAttrOfType<StringAttr>("_tpu_replicate");
  if (!replicate_attr) return success();

  // Collect `num_replicas` and `num_cores_per_replica` attributes.
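  // A launch_func nested under a tf_device.replicate op is replicated; its
  // replica count comes from the replicate op's `n` attribute and otherwise
  // defaults to 1.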
  int num_replicas = 1;
  tf_device::ReplicateOp replicate =
      launch_func.getParentOp()
          ? llvm::dyn_cast_or_null<tf_device::ReplicateOp>(
                launch_func.getParentOp())
          : nullptr;
  if (replicate) num_replicas = replicate.n().getLimitedValue();

  auto num_cores_per_replica_attr =
      launch_func.getAttrOfType<IntegerAttr>(kNumCoresPerReplicaAttr);
  if (!num_cores_per_replica_attr)
    return launch_func.emitOpError(
        CreateMissingAttributeMsg(kNumCoresPerReplicaAttr));

  int num_cores_per_replica = num_cores_per_replica_attr.getInt();

  // Determine compilation and execution devices.
  std::string compilation_device;
  llvm::SmallVector<std::string, 8> execution_devices;
  auto status = tensorflow::GetTPUCompilationAndExecutionDevices(
      devices, num_replicas, num_cores_per_replica, &compilation_device,
      &execution_devices);
  if (!status.ok())
    return launch_func.emitError()
           << "error in fetching TPU compilation/execution devices: "
           << status.error_message();

  // Create compile op.
  builder->setInsertionPoint(launch_func);
  Operation* compile_op =
      BuildCompileOp(launch_func, num_replicas, num_cores_per_replica,
                     compilation_device, builder);
  if (!compile_op) return failure();

  // After rewrite, find if there is a TPUCompilationResultOp in the block with
  // the same _tpu_replicate attribute and replace it with the result of the
  // compile op. This op is used during graph creation as a placeholder to hook
  // up the other ops that are intended to consume the compile result.
  Block* block = launch_func.getOperation()->getBlock();
  for (auto compile_result_op : block->getOps<TF::TPUCompilationResultOp>())
    compile_result_op.output().replaceAllUsesWith(compile_op->getResult(0));

  BuildTPUCompileSucceededAssertOp(compile_op, builder);

  // Create execute op.
  Operation* execute_op = BuildExecuteOp(compile_op, launch_func, builder);
  launch_func.replaceAllUsesWith(execute_op);
  launch_func.erase();

  // If computation is replicated, execution devices are assigned to the
  // replicate. Otherwise there is only one execution device and the device is
  // assigned to the execute op.
  if (replicate) {
    llvm::SmallVector<llvm::StringRef, 8> execution_device_refs(
        execution_devices.begin(), execution_devices.end());
    replicate.setAttr(kDevicesAttr,
                      builder->getStrArrayAttr(execution_device_refs));
  } else {
    execute_op->setAttr(kDeviceAttr,
                        builder->getStringAttr(execution_devices.front()));
  }

  return success();
}

void TPURewritePass::runOnModule() {
  llvm::SmallVector<tensorflow::DeviceNameUtils::ParsedName, 8> devices;
  if (failed(tensorflow::GetDevicesFromOp(getModule(), &devices)))
    return signalPassFailure();

  OpBuilder builder(&getContext());
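  // Rewrite every tf_device.launch_func in the module, aborting the pass on
  // the first failed rewrite.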
  auto result = getModule().walk([&](tf_device::LaunchFuncOp op) {
    if (failed(Rewrite(op, devices, &builder))) return WalkResult::interrupt();

    return WalkResult::advance();
  });

  if (result.wasInterrupted()) return signalPassFailure();

  // Eliminate TPUCompilationResultOp now that the rewrite is complete.
  getModule().walk([&](TF::TPUCompilationResultOp op) { op.erase(); });

  // TODO(b/139377366): Remove functions that are no longer needed.
}

}  // namespace

std::unique_ptr<OpPassBase<ModuleOp>> CreateTPURewritePass() {
  return std::make_unique<TPURewritePass>();
}

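// Registers the pass under the `tf-tpu-rewrite` flag so pass drivers can
// schedule it by name (e.g. it should be usable as `-tf-tpu-rewrite` from the
// tf-opt tool, assuming the standard MLIR pass registration flow).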
static PassRegistration<TPURewritePass> pass(
    "tf-tpu-rewrite",
    "Rewriting `tf_device.launch_func` on TPUs into TPU runtime ops");

}  // namespace TFTPU
}  // namespace mlir