1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <cstdint>
17 #include <string>
18 #include <type_traits>
19
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/Support/Casting.h"
25 #include "llvm/Support/CommandLine.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Support/raw_ostream.h"
28 #include "mlir/IR/Attributes.h" // TF:llvm-project
29 #include "mlir/IR/Builders.h" // TF:llvm-project
30 #include "mlir/IR/Module.h" // TF:llvm-project
31 #include "mlir/IR/Operation.h" // TF:llvm-project
32 #include "mlir/IR/StandardTypes.h" // TF:llvm-project
33 #include "mlir/IR/Types.h" // TF:llvm-project
34 #include "mlir/Pass/Pass.h" // TF:llvm-project
35 #include "mlir/Pass/PassRegistry.h" // TF:llvm-project
36 #include "mlir/Support/LogicalResult.h" // TF:llvm-project
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
40 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
41 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
42 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
43 #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
44 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
45 #include "tensorflow/compiler/xla/xla.pb.h"
46 #include "tensorflow/compiler/xla/xla_data.pb.h"
47 #include "tensorflow/core/framework/tensor_shape.h"
48 #include "tensorflow/core/framework/tensor_shape.pb.h"
49 #include "tensorflow/core/framework/types.pb.h"
50 #include "tensorflow/core/lib/core/status.h"
51 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
52 #include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
53 #include "tensorflow/core/util/device_name_utils.h"
54
55 namespace mlir {
56 namespace TFTPU {
57
// Command-line flag controlling how the TPUCompileMetadataProto is embedded
// in the generated 'tf._TPUCompileMlir' op: as a human-readable proto debug
// string when set, or as a serialized binary proto otherwise.
// NOLINTNEXTLINE
static llvm::cl::opt<bool> tpu_compile_metadata_debug(
    "tpu_compile_metadata_debug",
    llvm::cl::desc("Serialize TPUCompileMetadataProto metadata in "
                   "'tf._TPUCompileMlir' op as a proto debug string"));

// Names of attributes read from `tf_device.launch_func` /
// `tf_device.replicate` ops and the enclosing module during rewriting.
constexpr char kNumReplicasAttr[] = "num_replicas";
constexpr char kNumCoresPerReplicaAttr[] = "num_cores_per_replica";
constexpr char kStepMarkerLocationAttr[] = "step_marker_location";
constexpr char kPaddingMapAttr[] = "padding_map";
constexpr char kDeviceAttr[] = "device";
constexpr char kDevicesAttr[] = "devices";
constexpr char kVersionsAttr[] = "tf.versions";
71
72 // Rewrites `tf_device.launch_func` operations assigned to TPU into actual TPU
73 // jit-compile runtime ops.
74 //
75 // For example:
76 // %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster", func =
77 // @tpu_func}
78 // %2 = "tf.SomeOp"(%1)
79 //
80 // Would become following ops (unimportant attributes, types are omitted):
81 // %1 = "tf.Shape"(%0)
// %2:2 = "tf._TPUCompileMlir"(%1) {mlir_module = "<Serialized @tpu_func>"}
83 // "tf.TPUCompileSucceededAssert"(%2#0)
84 // %3 = "tf.TPUExecute"(%0, %2#1)
85 // %4 = "tf.SomeOp"(%3)
86
87 namespace {
// Module pass that rewrites TPU-assigned `tf_device.launch_func` ops into the
// TPU runtime ops (`tf._TPUCompileMlir`, `tf.TPUCompileSucceededAssert`,
// `tf.TPUExecute`). Registered below as `tf-tpu-rewrite`.
struct TPURewritePass : public ModulePass<TPURewritePass> {
  void runOnModule() override;
};
91
92 // Creates a missing attribute error message.
CreateMissingAttributeMsg(llvm::StringRef attribute)93 std::string CreateMissingAttributeMsg(llvm::StringRef attribute) {
94 return llvm::formatv("requires attribute '{0}'", attribute).str();
95 }
96
EncapsulateFuncAndSerialize(FuncOp entry_func,std::string * serialized_func_module)97 LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func,
98 std::string* serialized_func_module) {
99 ModuleOp module = entry_func.getParentOfType<ModuleOp>();
100 SymbolTable entry_module_table(module);
101 llvm::SmallVector<FuncOp, 4> referenced({entry_func});
102
103 // Create a new module to hold func and all referenced functions.
104 OwningModuleRef module_for_func =
105 ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext()));
106 auto parent_module = entry_func.getParentOfType<ModuleOp>();
107 auto versions_attr = parent_module.getAttr(kVersionsAttr);
108 if (!versions_attr)
109 return parent_module.emitError(CreateMissingAttributeMsg(kVersionsAttr));
110
111 module_for_func.get().getOperation()->setAttr(kVersionsAttr, versions_attr);
112 SymbolTable symbol_table(module_for_func.get());
113
114 while (!referenced.empty()) {
115 auto func = referenced.pop_back_val();
116
117 // Skip functions that have already been cloned into new module.
118 if (symbol_table.lookup<FuncOp>(func.getName())) continue;
119
120 // Find any SymbolRefAttr in func that maps to a FuncOp. We need to clone
121 // all found FuncOps to new_module to make sure new_module is
122 // self-contained.
123 Optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(func);
124 assert(uses && "expected to be able to collect symbol uses");
125 for (SymbolTable::SymbolUse use : *uses) {
126 FuncOp referenced_func = entry_module_table.lookup<FuncOp>(
127 use.getSymbolRef().cast<FlatSymbolRefAttr>().getValue());
128
129 // Skip Symbols that do not map to a function.
130 if (!referenced_func) continue;
131
132 referenced.emplace_back(referenced_func);
133 }
134
135 auto clone = func.clone();
136 if (clone.getName() == entry_func.getName()) {
137 // We can simply change name of TPU program's main function because there
138 // should be no other reference to it.
139 clone.setName("main");
140 }
141 symbol_table.insert(clone);
142 }
143
144 // Serialize module and return.
145 {
146 llvm::raw_string_ostream os(*serialized_func_module);
147 module_for_func.get().print(os);
148 }
149 return success();
150 }
151
152 // Populates a TPUCompileMetadataProto from attributes of a
153 // `tf_device::LaunchFuncOp`. If any necessary attributes are missing from the
154 // op, a failure will be returned.
155 // TODO(lyandy): Support session handle and guaranteed consts.
SetMetadataProtoFromLaunchFuncOp(tf_device::LaunchFuncOp op,int num_replicas,int num_cores_per_replica,tensorflow::tpu::TPUCompileMetadataProto * metadata)156 LogicalResult SetMetadataProtoFromLaunchFuncOp(
157 tf_device::LaunchFuncOp op, int num_replicas, int num_cores_per_replica,
158 tensorflow::tpu::TPUCompileMetadataProto* metadata) {
159 metadata->set_num_replicas(num_replicas);
160 metadata->set_num_cores_per_replica(num_cores_per_replica);
161
162 auto step_marker_location =
163 op.getAttrOfType<StringAttr>(kStepMarkerLocationAttr);
164 if (!step_marker_location)
165 return op.emitOpError(CreateMissingAttributeMsg(kStepMarkerLocationAttr));
166
167 // Default to `STEP_MARK_AT_ENTRY` for step marker location if attribute is
168 // empty.
169 xla::DebugOptions::StepMarkerLocation location =
170 xla::DebugOptions::STEP_MARK_AT_ENTRY;
171 if (!step_marker_location.getValue().empty() &&
172 !xla::DebugOptions::StepMarkerLocation_Parse(
173 std::string(step_marker_location.getValue()), &location))
174 return op.emitOpError(llvm::formatv("bad '{0}' attribute with value '{1}'",
175 kStepMarkerLocationAttr,
176 step_marker_location.getValue()));
177
178 metadata->set_step_marker_location(location);
179
180 auto padding_map = op.getAttrOfType<ArrayAttr>(kPaddingMapAttr);
181 if (!padding_map)
182 return op.emitOpError(CreateMissingAttributeMsg(kPaddingMapAttr));
183
184 for (const auto padding_and_idx : llvm::enumerate(padding_map)) {
185 auto& padding_attr = padding_and_idx.value();
186 auto padding_attr_str = padding_attr.dyn_cast<StringAttr>();
187 if (!padding_attr_str)
188 return op.emitOpError(
189 llvm::formatv("bad '{0}' attribute at index {1}, not a string",
190 kPaddingMapAttr, padding_and_idx.index()));
191
192 tensorflow::tpu::PaddingMap* padding =
193 metadata->mutable_padding_maps()->Add();
194 if (!padding->ParseFromString(std::string(padding_attr_str.getValue())))
195 return op.emitOpError(llvm::formatv(
196 "bad '{0}' attribute at index {1} with value '{2}'", kPaddingMapAttr,
197 padding_and_idx.index(), padding_attr_str.getValue()));
198 }
199
200 // Set args metadata in proto.
201 for (auto operand_type_and_idx : llvm::enumerate(op.getOperandTypes())) {
202 Type operand_type = operand_type_and_idx.value();
203 tensorflow::tpu::TPUCompileMetadataProto::Arg* arg = metadata->add_args();
204 tensorflow::DataType dtype;
205 tensorflow::Status status =
206 tensorflow::ConvertToDataType(operand_type, &dtype);
207 if (!status.ok())
208 return op.emitOpError(
209 llvm::formatv("failed to determine operand type at index {0}: {1}",
210 operand_type_and_idx.index(), status.error_message()));
211
212 arg->set_dtype(dtype);
213 // TODO(lyandy): Support other arg kinds.
214 if (dtype == tensorflow::DT_RESOURCE)
215 arg->set_kind(tensorflow::tpu::TPUCompileMetadataProto::Arg::VARIABLE);
216 else
217 arg->set_kind(tensorflow::tpu::TPUCompileMetadataProto::Arg::PARAMETER);
218
219 // Populate argument shapes.
220 *arg->mutable_shape() = tensorflow::TensorShapeProto();
221 if (auto ranked_tensor_type = operand_type.dyn_cast<RankedTensorType>()) {
222 tensorflow::TensorShapeProto shape_proto;
223 ConvertToTensorShapeProto(ranked_tensor_type.getShape(), &shape_proto);
224 *arg->mutable_shape() = std::move(shape_proto);
225 } else {
226 arg->mutable_shape()->set_unknown_rank(true);
227 }
228
229 // TODO(lyandy): Determine proper sharding of args once topology and devices
230 // are propagated to the pass.
231 xla::OpSharding sharding;
232 sharding.set_type(xla::OpSharding::MAXIMAL);
233 sharding.add_tile_assignment_dimensions(1);
234 sharding.add_tile_assignment_devices(0);
235 *arg->mutable_sharding() = std::move(sharding);
236 }
237
238 // Set retvals metadata in proto.
239 // TODO(lyandy): Determine proper sharding of retvals once topology and
240 // devices is propagated to the pass.
241 for (int i = 0; i < op.getNumResults(); ++i) {
242 xla::OpSharding sharding;
243 sharding.set_type(xla::OpSharding::MAXIMAL);
244 sharding.add_tile_assignment_dimensions(1);
245 sharding.add_tile_assignment_devices(0);
246 *metadata->add_retvals()->mutable_sharding() = std::move(sharding);
247 }
248
249 return success();
250 }
251
252 // Create a `tf._TPUCompileMlir` that contains a MLIR module that is
253 // functionally equivalent to the function referenced by launch_func.
BuildCompileOp(tf_device::LaunchFuncOp launch_func,int num_replicas,int num_cores_per_replica,llvm::StringRef compilation_device,OpBuilder * builder)254 Operation* BuildCompileOp(tf_device::LaunchFuncOp launch_func, int num_replicas,
255 int num_cores_per_replica,
256 llvm::StringRef compilation_device,
257 OpBuilder* builder) {
258 // TODO(b/139377366): Use tf_tpu.compile build method when it is defined.
259 OperationState compile_op_state(launch_func.getLoc(), "tf._TPUCompileMlir");
260
261 // Set metadata from attributes.
262 tensorflow::tpu::TPUCompileMetadataProto metadata;
263 if (failed(SetMetadataProtoFromLaunchFuncOp(
264 launch_func, num_replicas, num_cores_per_replica, &metadata)))
265 return nullptr;
266
267 std::string txt_metadata;
268 if (tpu_compile_metadata_debug)
269 txt_metadata = metadata.DebugString();
270 else
271 metadata.SerializeToString(&txt_metadata);
272
273 compile_op_state.addAttribute("metadata",
274 builder->getStringAttr(txt_metadata));
275
276 // Build a shape op for each input to launch_func.
277 // TODO(b/139377366): When shape inference is ready, we can use compile time
278 // shape inference to get inputs that have static shapes and only use shape
279 // ops for the rest.
280 llvm::SmallVector<Value, 4> compile_op_operands;
281 compile_op_operands.reserve(launch_func.getNumOperands());
282
283 for (auto operand_and_idx : llvm::enumerate(launch_func.getOperands())) {
284 // Skip adding shape op for operands that have static shapes.
285 tensorflow::PartialTensorShape shape(
286 metadata.args(operand_and_idx.index()).shape());
287 if (shape.IsFullyDefined()) continue;
288
289 auto shape_op = builder->create<TF::ShapeOp>(
290 launch_func.getLoc(),
291 RankedTensorType::get({-1}, builder->getIntegerType(64)),
292 operand_and_idx.value());
293 compile_op_operands.emplace_back(shape_op.getResult());
294 }
295 compile_op_state.addOperands(compile_op_operands);
296 compile_op_state.addAttribute(
297 "NumDynamicShapes",
298 builder->getI64IntegerAttr(compile_op_operands.size()));
299
300 FlatSymbolRefAttr func_attr =
301 launch_func.getAttrOfType<FlatSymbolRefAttr>("func");
302 if (!func_attr) {
303 launch_func.emitOpError("does not have `func` attribute");
304 return nullptr;
305 }
306 FuncOp func = launch_func.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
307 func_attr.getValue());
308
309 std::string txt_module;
310 if (failed(EncapsulateFuncAndSerialize(func, &txt_module))) return nullptr;
311 compile_op_state.addAttribute("mlir_module",
312 builder->getStringAttr(txt_module));
313
314 compile_op_state.addAttribute(kDeviceAttr,
315 builder->getStringAttr(compilation_device));
316
317 // Result #0 is a string indicating whether compilation is successful or not.
318 compile_op_state.addTypes(
319 RankedTensorType::get({}, builder->getType<TF::StringType>()));
320
321 // Result #1 is key to look up executable binary in compilation cache.
322 compile_op_state.addTypes(
323 RankedTensorType::get({}, builder->getType<TF::StringType>()));
324
325 return builder->createOperation(compile_op_state);
326 }
327
328 // Creates a `tf.TPUExecute` op that executes TPU program generated by
329 // `compile_op`.
BuildExecuteOp(Operation * compile_op,tf_device::LaunchFuncOp launch_func,OpBuilder * builder)330 Operation* BuildExecuteOp(Operation* compile_op,
331 tf_device::LaunchFuncOp launch_func,
332 OpBuilder* builder) {
333 // TPUExecute inherits all launch_func inputs, and takes an additional input
334 // for compilation cache key.
335 llvm::SmallVector<Value, 4> tensor_inputs(launch_func.getOperands());
336 tensor_inputs.push_back(compile_op->getResult(1));
337
338 // TODO(b/139377366): Need to snapshot all resource variable inputs in
339 // follow-up CLs.
340
341 // TPUExecute has same output types as launch_func.
342 return builder->create<TF::TPUExecuteOp>(
343 launch_func.getLoc(), launch_func.getResultTypes(), tensor_inputs,
344 llvm::ArrayRef<NamedAttribute>{});
345 }
346
347 // Creates a `tf.TPUCompileSucceededAssert` operation that parses compilation
348 // status of `compile_op` to check whether compilation is successful.
BuildTPUCompileSucceededAssertOp(Operation * compile_op,OpBuilder * builder)349 void BuildTPUCompileSucceededAssertOp(Operation* compile_op,
350 OpBuilder* builder) {
351 OperationState assert_op_state(compile_op->getLoc(),
352 "tf.TPUCompileSucceededAssert");
353 assert_op_state.addOperands(compile_op->getResult(0));
354 builder->createOperation(assert_op_state);
355 }
356
357 // Rewrites a `tf_device.launch_func` operation into a set of TPU Runtime
358 // Operations that jit-compiles and executes function in `tf_device.launch_func`
359 // on TPU. Device assignment is determined from available devices in `devices`.
360 // If it is not possible to rewrite the operation or device assignment fails, a
361 // failure will be returned.
362 //
363 // For example, a non replicated `tf_device.launch_func`:
364 //
365 // func @main(%arg0: tensor<i1>) {
366 // %0 = "tf_device.launch_func"(%arg0)
367 // {_tpu_replicate = "cluster0", device = "", func = @_func} :
368 // (tensor<i1>) -> tensor<i1>
369 // return
370 // }
371 //
372 // will be rewritten as:
373 //
374 // func @main(%arg0: tensor<i1>) {
375 // %0 = "tf.Shape"(%arg0) : (tensor<i1>) -> tensor<?xi32>
376 // %1:2 = "tf._TPUCompileMlir"(%0) {device = "/CPU:0"} :
377 // (tensor<?xi32>) -> (tensor<!tf.string>, tensor<!tf.string>)
//   %2 = "tf.TPUExecute"(%arg0, %1#1) {device = "/TPU:0"} :
379 // (tensor<i1>, tensor<!tf.string>) -> tensor<i1>
380 // return
381 // }
382 //
383 // and a replicated `tf_device.launch_func`:
384 //
385 // func @main(%arg0: tensor<i1>, %arg1: tensor<i1>) {
386 // %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i1>)
387 // {n = 2 : i32} {
388 // %1 = "tf_device.launch_func"(%ri)
389 // {_tpu_replicate = "cluster0", device = "", func = @_func} :
390 // (tensor<i1>) -> tensor<i1>
391 // tf_device.return %1 : tensor<i1>
392 // }
393 // return
394 // }
395 //
396 // will be rewritten as:
397 //
398 // func @main(%arg0: tensor<i1>, %arg1: tensor<i1>) {
399 // %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i1>)
400 // {n = 2 : i32, devices = ["/TPU:0", "/TPU:1"]} {
401 // %1 = "tf.Shape"(%ri) : (tensor<i1>) -> tensor<?xi32>
402 // %2:2 = "tf._TPUCompileMlir"(%1) {device = "/CPU:0"} :
403 // (tensor<?xi32>) -> (tensor<!tf.string>, tensor<!tf.string>)
//     %3 = "tf.TPUExecute"(%ri, %2#1) :
405 // (tensor<i1>, tensor<!tf.string>) -> tensor<i1>
406 // tf_device.return %3 : tensor<i1>
407 // }
408 // return
409 // }
Rewrite(tf_device::LaunchFuncOp launch_func,llvm::ArrayRef<tensorflow::DeviceNameUtils::ParsedName> devices,OpBuilder * builder)410 LogicalResult Rewrite(
411 tf_device::LaunchFuncOp launch_func,
412 llvm::ArrayRef<tensorflow::DeviceNameUtils::ParsedName> devices,
413 OpBuilder* builder) {
414 // Skip non-tpu device launch_func.
415 auto replicate_attr = launch_func.getAttrOfType<StringAttr>("_tpu_replicate");
416 if (!replicate_attr) return success();
417
418 // Collect `num_replicas` and `num_cores_per_replica` attributes.
419 int num_replicas = 1;
420 tf_device::ReplicateOp replicate =
421 launch_func.getParentOp()
422 ? llvm::dyn_cast_or_null<tf_device::ReplicateOp>(
423 launch_func.getParentOp())
424 : nullptr;
425 if (replicate) num_replicas = replicate.n().getLimitedValue();
426
427 auto num_cores_per_replica_attr =
428 launch_func.getAttrOfType<IntegerAttr>(kNumCoresPerReplicaAttr);
429 if (!num_cores_per_replica_attr)
430 return launch_func.emitOpError(
431 CreateMissingAttributeMsg(kNumCoresPerReplicaAttr));
432
433 int num_cores_per_replica = num_cores_per_replica_attr.getInt();
434
435 // Determine compilation and execution devices.
436 std::string compilation_device;
437 llvm::SmallVector<std::string, 8> execution_devices;
438 auto status = tensorflow::GetTPUCompilationAndExecutionDevices(
439 devices, num_replicas, num_cores_per_replica, &compilation_device,
440 &execution_devices);
441 if (!status.ok())
442 return launch_func.emitError()
443 << "error in fetching TPU compilation/execution devices: "
444 << status.error_message();
445
446 // Create compile op;
447 builder->setInsertionPoint(launch_func);
448 Operation* compile_op =
449 BuildCompileOp(launch_func, num_replicas, num_cores_per_replica,
450 compilation_device, builder);
451 if (!compile_op) return failure();
452
453 // After rewrite, find if there is a TPUCompilationResultOp in the block with
454 // the same _tpu_replicate attribute and replace it with the result of the
455 // compile op. This op is used as a placeholder to hook during graph creation
456 // the other ops that are intended to consume the compile result.
457 Block* block = launch_func.getOperation()->getBlock();
458 for (auto compile_result_op : block->getOps<TF::TPUCompilationResultOp>())
459 compile_result_op.output().replaceAllUsesWith(compile_op->getResult(0));
460
461 BuildTPUCompileSucceededAssertOp(compile_op, builder);
462
463 // Create execute op.
464 Operation* execute_op = BuildExecuteOp(compile_op, launch_func, builder);
465 launch_func.replaceAllUsesWith(execute_op);
466 launch_func.erase();
467
468 // If computation is replicated, execution devices are assigned to the
469 // replicate. Otherwise there is only one execution device and the device is
470 // assigned to the execute op.
471 if (replicate) {
472 llvm::SmallVector<llvm::StringRef, 8> execution_device_refs(
473 execution_devices.begin(), execution_devices.end());
474 replicate.setAttr(kDevicesAttr,
475 builder->getStrArrayAttr(execution_device_refs));
476 } else {
477 execute_op->setAttr(kDeviceAttr,
478 builder->getStringAttr(execution_devices.front()));
479 }
480
481 return success();
482 }
483
runOnModule()484 void TPURewritePass::runOnModule() {
485 llvm::SmallVector<tensorflow::DeviceNameUtils::ParsedName, 8> devices;
486 if (failed(tensorflow::GetDevicesFromOp(getModule(), &devices)))
487 return signalPassFailure();
488
489 OpBuilder builder(&getContext());
490 auto result = getModule().walk([&](tf_device::LaunchFuncOp op) {
491 if (failed(Rewrite(op, devices, &builder))) return WalkResult::interrupt();
492
493 return WalkResult::advance();
494 });
495
496 if (result.wasInterrupted()) return signalPassFailure();
497
498 // Eliminate TPUCompilationResultOp now that the rewrite is complete.
499 getModule().walk([&](TF::TPUCompilationResultOp op) { op.erase(); });
500
501 // TODO(b/139377366): Remove functions that are no longer needed.
502 }
503
504 } // namespace
505
CreateTPURewritePass()506 std::unique_ptr<OpPassBase<ModuleOp>> CreateTPURewritePass() {
507 return std::make_unique<TPURewritePass>();
508 }
509
// Registers the pass under the `tf-tpu-rewrite` command-line flag.
static PassRegistration<TPURewritePass> pass(
    "tf-tpu-rewrite",
    "Rewriting `tf_device.launch_func` on TPUs into TPU runtime ops");
513
514 } // namespace TFTPU
515 } // namespace mlir
516