/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements lowering of the TF dialect to TFRT CoreRuntime
// ExecuteOp. This lowering pass is heavily experimental and incomplete.
// External code should not depend on the code here, and please do not treat
// it as an example of "the path forward" for this.

#include <vector>

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassOptions.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h"
#include "tensorflow/compiler/mlir/tfrt/analysis/tensor_array_side_effect_analysis.h"
#include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.h"
#include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h"
#include "tensorflow/compiler/mlir/tfrt/jit/opdefs/tf_jitrt_ops.h"
#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.h"
#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/corert_converter.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/set_shape_invariant_in_while_ops.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/tstring.h"
#include "tfrt/jitrt/opdefs/jitrt_ops.h"  // from @tf_runtime
#include "tfrt/basic_kernels/opdefs/basic_kernels.h"  // from @tf_runtime
#include "tfrt/basic_kernels/opdefs/tfrt_base.h"  // from @tf_runtime
#include "tfrt/basic_kernels/opdefs/types.h"  // from @tf_runtime
#include "tfrt/core_runtime/opdefs/attributes.h"  // from @tf_runtime
#include "tfrt/core_runtime/opdefs/core_runtime.h"  // from @tf_runtime
#include "tfrt/core_runtime/opdefs/types.h"  // from @tf_runtime
#include "tfrt/distributed_runtime/opdefs/kernels.h"  // from @tf_runtime
#include "tfrt/distributed_runtime/opdefs/types.h"  // from @tf_runtime
#include "tfrt/test_kernels/opdefs/test_kernels.h"  // from @tf_runtime

namespace tensorflow {
namespace {

constexpr unsigned kFallbackBenefit = 1;
constexpr unsigned kCoreRTBenefit = 2;
constexpr char kGpuDeviceName[] =
    "/job:localhost/replica:0/task:0/device:GPU:0";
constexpr char kTFDeviceAttr[] = "tf.device";
constexpr char kTFRTDeviceAttr[] = "tfrt.device";
constexpr char kDeviceAttr[] = "device";
constexpr char kHostAttr[] = "host";
constexpr char kDeviceTypeTpu[] = "TPU";
constexpr int64_t kDefaultCheapCost = 1;

void getDependentConversionDialects(mlir::DialectRegistry &registry) {
  registry.insert<tfrt::corert::CoreRTDialect, mlir::func::FuncDialect,
                  tfrt::fallback_async::FallbackAsyncDialect,
                  tfrt::compiler::TFRTDialect, tfrt::dist::DistributedDialect,
                  tf_jitrt::JitRuntimeDialect>();
}

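// Returns the input chain of the function enclosing `op`. This assumes the
// enclosing FuncOp has already been rewritten so that its first argument is a
// !tfrt.chain (see TFRTFuncOpSignatureConversion below).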
mlir::Value GetFunctionInputChain(mlir::Operation *op) {
  auto func_op = op->getParentOfType<mlir::func::FuncOp>();
  return func_op.getArgument(0);
}

// Convert TF dialect ops to tfrt_fallback.executeop for non-side-effecting ops
// and tfrt_fallback.executeop.seq for side-effecting ops.
//
// For example,
//
// %0 = "tf.MatMul"(%arg0, %arg1) {device = "cpu"} :
//    (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
//
// is converted to
//
// %result = tfrt_fallback.executeop device("cpu")
//    "tf.MatMul"(%arg0, %arg1) : 1
//
class FallbackExecuteOpConversion : public mlir::ConversionPattern {
 public:
  FallbackExecuteOpConversion(
      mlir::MLIRContext *context, CoreRTConverter *corert_converter,
      tfrt_compiler::FallbackConverter *fallback_converter,
      const tfrt_compiler::CostAnalysis *cost_analysis,
      bool tpu_lower_to_fallback, bool target_tpurt)
      : mlir::ConversionPattern(mlir::Pattern::MatchAnyOpTypeTag(),
                                kFallbackBenefit, context),
        corert_converter_(*corert_converter),
        fallback_converter_(*fallback_converter),
        cost_analysis_(*cost_analysis),
        tpu_lower_to_fallback_(tpu_lower_to_fallback),
        target_tpurt_(target_tpurt) {}

  LogicalResult matchAndRewrite(
      mlir::Operation *op, ArrayRef<mlir::Value> operands,
      mlir::ConversionPatternRewriter &rewriter) const override {
    if (!UseFallback(op)) return failure();

    if (target_tpurt_ && IsTpuCompileAndExecuteOps(op)) return failure();

    corert_converter_.MaterializeDerivedAttributes(op);

    mlir::StringAttr device = op->getAttrOfType<mlir::StringAttr>(kDeviceAttr);
    if (!device || device.getValue().empty())
      return op->emitWarning("failed to find a non-empty 'device' attribute");
    op->removeAttr(kDeviceAttr);
    auto parsed_device_name =
        corert_converter_.ParseDeviceName(device.getValue());
    if (!parsed_device_name)
      return op->emitWarning("failed to parse the device name");

    // Convert the function (symbol) attributes to an array of string
    // attributes, which represents the function names.
    llvm::SmallVector<mlir::StringAttr, 4> func_attr_keys;
    mlir::ArrayAttr op_func_attrs =
        corert_converter_.CreateOpFuncAttrs(op->getAttrs(), &func_attr_keys);

    // Remove the function attributes, which have already been processed.
    for (const auto &key : func_attr_keys) op->removeAttr(key);

    mlir::ArrayAttr op_attrs = corert_converter_.CreateOpAttrs(op->getAttrs());
    if (!op_attrs) return op->emitWarning("failed to lower attributes.");

    mlir::StringAttr op_name =
        rewriter.getStringAttr(op->getName().getStringRef());

    // Ops with _tpu_replicate attribute are TPU ops.
    bool is_tpu_op = op->hasAttr("_tpu_replicate") ||
                     llvm::isa<mlir::TF::TPUReplicatedInputOp,
                               mlir::TF::TPUReplicatedOutputOp>(op);

    if ((parsed_device_name->device_type == kDeviceTypeTpu &&
         !tpu_lower_to_fallback_) ||
        // Convert ops running on TPU to CoreRT dialect to prevent the creation
        // of tfrt_fallback_async.createop for them.
        // These ops will be encountered here only when using fallback to run
        // TPU models, in which case, these ops are assumed to be in a function
        // called by a TPUPartitionedCall op and will be compiled in
        // TPUPartitionedCall op via FunctionLibraryRuntime and not be processed
        // by BEFExecutor.
        is_tpu_op) {
      return ConvertToCoreRTExecuteOp(
          op, operands, parsed_device_name->op_handler_name, op_attrs,
          op_func_attrs, op_name, rewriter);
    }

    // For DEVICE_CPU, DEVICE_DEFAULT, and DEVICE_TPU_SYSTEM, we use fallback.
    // Note that TPU bridge should handle all ops that are required to be
    // executed on TPU. So if there are still ops that are placed on TPU at this
    // stage of lowering TF to TFRT, then these ops are supposed to be executed
    // on host.
    return ConvertToFallbackExecuteOp(op, operands, device, op_attrs,
                                      op_func_attrs, op_name,
                                      fallback_converter_, rewriter);
  }

 private:
  // Return true if this op can be lowered to fallback ops.
  bool UseFallback(mlir::Operation *op) const {
    if (!llvm::isa<mlir::TF::TensorFlowDialect>(op->getDialect())) return false;

    // Below is a blocklist of ops that should not go through CoreRTExecuteOp
    // conversion.
    // TODO(b/173017701): have a centralized place to hold the information
    // whether a TF op should be lowered to FallbackExecute op.
    return !llvm::isa<mlir::TF::_TfrtSetResourceOp,
                      mlir::TF::_TfrtGetResourceOp,
                      // Do not use fallback on the TPU fused
                      // compile_and_execute kernel.
                      mlir::TF::TPUCompileMlirAndExecuteOp,
                      // Specifically handle control flow ops.
                      mlir::TF::CaseOp, mlir::TF::IfOp, mlir::TF::WhileOp,
                      mlir::TF::StatefulPartitionedCallOp,
                      mlir::TF::PartitionedCallOp, mlir::TF::LegacyCallOp>(op);
  }

  bool IsTpuCompileAndExecuteOps(mlir::Operation *op) const {
    return llvm::isa<mlir::TF::_TPUCompileMlirOp,
                     mlir::TF::TPUCompileSucceededAssertOp,
                     mlir::TF::TPUExecuteOp>(op);
  }

  mlir::LogicalResult ConvertToFallbackExecuteOp(
      mlir::Operation *op, ValueRange operands, mlir::StringAttr device,
      mlir::ArrayAttr op_attrs, mlir::ArrayAttr op_func_attrs,
      mlir::StringAttr op_name,
      tfrt_compiler::FallbackConverter &fallback_converter,
      mlir::ConversionPatternRewriter &rewriter) const;

  mlir::LogicalResult ConvertToCoreRTExecuteOp(
      mlir::Operation *op, ValueRange operands, llvm::StringRef op_handler_name,
      mlir::ArrayAttr op_attrs, mlir::ArrayAttr op_func_attrs,
      mlir::StringAttr op_name,
      mlir::ConversionPatternRewriter &rewriter) const;

  CoreRTConverter &corert_converter_;
  tfrt_compiler::FallbackConverter &fallback_converter_;
  const tfrt_compiler::CostAnalysis &cost_analysis_;
  bool tpu_lower_to_fallback_;
  bool target_tpurt_;
};

mlir::LogicalResult FallbackExecuteOpConversion::ConvertToFallbackExecuteOp(
    mlir::Operation *op, ValueRange operands, mlir::StringAttr device,
    mlir::ArrayAttr op_attrs, mlir::ArrayAttr op_func_attrs,
    mlir::StringAttr op_name,
    tfrt_compiler::FallbackConverter &fallback_converter,
    mlir::ConversionPatternRewriter &rewriter) const {
  llvm::SmallVector<Type, 4> result_types(
      op->getNumResults(), rewriter.getType<tfrt::fallback::TFTensorType>());

  // Convert the operands to tfrt_fallback.tf_tensor if needed.
  llvm::SmallVector<mlir::Value, 4> new_operands;
  if (mlir::failed(tfrt_compiler::ConvertFallbackOperands(
          op, device.getValue(), operands, &new_operands, rewriter)))
    return failure();

  auto fallback_key =
      rewriter.getI64IntegerAttr(fallback_converter.GetNextFallbackKey());

  // Query cost analysis to assign costs.
  IntegerAttr cost;
  auto parsed_device_name =
      corert_converter_.ParseDeviceName(device.getValue());
  if (parsed_device_name && parsed_device_name->device_type == DEVICE_GPU) {
    // For GPU ops, the host only needs to dispatch them to GPUs, which should
    // be relatively cheap for the host.
    cost = rewriter.getI64IntegerAttr(kDefaultCheapCost);
  } else {
    cost = rewriter.getI64IntegerAttr(
        cost_analysis_.GetCost(op, fallback_key.getInt()));
  }

  if (mlir::MemoryEffectOpInterface::hasNoEffect(op)) {
    auto new_op = rewriter.create<tfrt::fallback_async::ExecuteOp>(
        op->getLoc(), result_types, new_operands, device, op_attrs,
        op_func_attrs, fallback_key, op_name, cost);
    fallback_converter.RegisterFallbackOp(new_op);
    rewriter.replaceOp(op, new_op.results());
  } else {
    auto in_chain = corert_converter_.GetLocalSideEffectChain(op, &rewriter);
    auto out_chain = in_chain;

    if (tfrt_compiler::IsTensorArrayOp(op)) {
      // If it is a tensor array op, we don't need to use
      // tfrt_fallback_async.executeop.seq because its operands/results already
      // take care of control dependencies.
      auto new_op = rewriter.create<tfrt::fallback_async::ExecuteOp>(
          op->getLoc(), result_types, new_operands, device, op_attrs,
          op_func_attrs, fallback_key, op_name, cost);
      fallback_converter.RegisterFallbackOp(new_op);
      rewriter.replaceOp(op, new_op.results());
    } else {
      // Create tfrt_fallback.executeop.seq if it is a side-effecting op.
      auto new_op = rewriter.create<tfrt::fallback_async::ExecuteOpSeq>(
          op->getLoc(), corert_converter_.chain_type(), result_types, in_chain,
          new_operands, device, op_attrs, op_func_attrs, fallback_key, op_name,
          cost);
      fallback_converter.RegisterFallbackOp(new_op);
      rewriter.replaceOp(op, new_op.results());
      out_chain = new_op.out_op_chain();
    }

    // Register the converted op so that it can be retrieved by successors.
    corert_converter_.RegisterLocalSideEffectChain(op, out_chain);
  }

  return success();
}

mlir::LogicalResult FallbackExecuteOpConversion::ConvertToCoreRTExecuteOp(
    mlir::Operation *op, ValueRange operands, llvm::StringRef op_handler_name,
    mlir::ArrayAttr op_attrs, mlir::ArrayAttr op_func_attrs,
    mlir::StringAttr op_name, mlir::ConversionPatternRewriter &rewriter) const {
  llvm::SmallVector<Type, 4> result_types(
      op->getNumResults(), rewriter.getType<tfrt::corert::TensorHandleType>());

  // Convert the operands to tensorhandles if needed.
  llvm::SmallVector<mlir::Value, 4> new_operands;
  if (mlir::failed(tfrt_compiler::ConvertCoreRTOperands(
          op, operands, &new_operands, rewriter)))
    return failure();

  // Get the op handler, or create one if it does not exist. Note that
  // ConvertOpHandler changes internal state, so it can only be called if the
  // rewrite is guaranteed to succeed afterwards.
  auto op_handler =
      corert_converter_.ConvertOpHandler(op, op_handler_name, &rewriter);
  if (!op_handler) return failure();

  if (mlir::MemoryEffectOpInterface::hasNoEffect(op)) {
    auto new_op = rewriter.create<tfrt::corert::ExecuteOp>(
        op->getLoc(), result_types, op_handler, new_operands, op_attrs,
        op_func_attrs, op_name);
    rewriter.replaceOp(op, new_op.results());
  } else {
    // Create corert.executeop.seq if it is a side-effecting op.
    auto new_op = rewriter.create<tfrt::corert::ExecuteOpSeq>(
        op->getLoc(), corert_converter_.chain_type(), result_types, op_handler,
        corert_converter_.GetLocalSideEffectChain(op, &rewriter), new_operands,
        op_attrs, op_func_attrs, op_name);
    rewriter.replaceOp(op, new_op.results());

    // Register the converted op so that it can be retrieved by successors.
    corert_converter_.RegisterLocalSideEffectChain(op, new_op.out_op_chain());
  }

  return success();
}

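// Convert tf.Const ops whose dtypes have no fast path (e.g. quantized types)
// to a fallback const-tensor-proto op (tfrt::fallback_async::ConstTensorProtoOp)
// that carries the serialized TensorProto. Illustrative sketch only; the exact
// printed form of the generated op is not shown:
//
//   %0 = "tf.Const"() {value = ... : tensor<2x!tf_type.qint8>} : ...
//
// becomes, roughly,
//
//   %0 = tfrt_fallback_async.const_tensor_proto "<serialized TensorProto>"
//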
class FallbackConstOpConversion
    : public mlir::OpConversionPattern<mlir::TF::ConstOp> {
 public:
  FallbackConstOpConversion(mlir::MLIRContext *context,
                            CoreRTConverter *corert_converter)
      : mlir::OpConversionPattern<mlir::TF::ConstOp>(context),
        corert_converter_(*corert_converter) {}

  mlir::LogicalResult matchAndRewrite(
      mlir::TF::ConstOp op, OpAdaptor adaptor,
      mlir::ConversionPatternRewriter &rewriter) const override {
    // Some data types are handled separately using a fast path.
    if (corert_converter_.IsSupportedNumericDType(op.dtype()) ||
        op.dtype().isa<mlir::TF::StringType>())
      return failure();

    // For other data types that do not have a fast path (e.g. quantized
    // types), we convert them to a serialized tensor proto.

    tensorflow::TensorProto tensor_proto;
    auto status = ConvertToTensorProto(op.value(), &tensor_proto);
    if (!status.ok()) return op.emitError(status.error_message());

    rewriter.replaceOpWithNewOp<tfrt::fallback_async::ConstTensorProtoOp>(
        op, rewriter.getType<tfrt::fallback::TFTensorType>(),
        tensor_proto.SerializeAsString());

    return success();
  }

 private:
  CoreRTConverter &corert_converter_;
};

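// Convert tf._TfrtSetResource to a fallback set_resource op
// (tfrt::fallback_async::SetResourceOp). The converted op consumes the local
// side-effect chain and stores its single tensor operand into the host CPU
// resource array at the given index, producing an output chain.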
class FallbackSetResourceOp
    : public mlir::OpConversionPattern<mlir::TF::_TfrtSetResourceOp> {
 public:
  FallbackSetResourceOp(mlir::MLIRContext *context,
                        CoreRTConverter *corert_converter)
      : mlir::OpConversionPattern<mlir::TF::_TfrtSetResourceOp>(context),
        corert_converter_(*corert_converter) {}

  mlir::LogicalResult matchAndRewrite(
      mlir::TF::_TfrtSetResourceOp op, OpAdaptor adaptor,
      mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::StringAttr device = op->getAttrOfType<mlir::StringAttr>(kDeviceAttr);
    if (!device || device.getValue().empty())
      return op->emitWarning("failed to find a non-empty 'device' attribute");

    // Currently the static resource is always on host CPU.
    //
    // TODO(chky): Support resource on other devices.
    llvm::SmallVector<mlir::Value, 4> new_operands;
    if (mlir::failed(tfrt_compiler::ConvertFallbackOperands(
            op, tfrt_compiler::GetDefaultCpuDeviceName(), adaptor.getOperands(),
            &new_operands, rewriter)))
      return failure();

    assert(new_operands.size() == 1);

    auto new_op = rewriter.create<tfrt::fallback_async::SetResourceOp>(
        op.getLoc(), corert_converter_.chain_type(),
        corert_converter_.GetLocalSideEffectChain(op, &rewriter),
        new_operands[0], device.getValue(), op.index());

    // Register the converted op so that it can be retrieved by successors.
    corert_converter_.RegisterLocalSideEffectChain(op, new_op.out_ch());

    rewriter.eraseOp(op);

    return success();
  }

 private:
  CoreRTConverter &corert_converter_;
};

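// Convert tf._TfrtGetResource to a fallback get_resource op
// (tfrt::fallback_async::GetResourceOp). A fresh ready chain is used as the
// input chain, so reading the pre-populated resource array does not wait on
// other side effects.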
class FallbackGetResourceOp
    : public mlir::OpConversionPattern<mlir::TF::_TfrtGetResourceOp> {
 public:
  FallbackGetResourceOp(mlir::MLIRContext *context,
                        CoreRTConverter *corert_converter)
      : mlir::OpConversionPattern<mlir::TF::_TfrtGetResourceOp>(context),
        corert_converter_(*corert_converter) {}

  mlir::LogicalResult matchAndRewrite(
      mlir::TF::_TfrtGetResourceOp op, OpAdaptor adaptor,
      mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::StringAttr device = op->getAttrOfType<mlir::StringAttr>(kDeviceAttr);
    if (!device || device.getValue().empty())
      return op->emitWarning("failed to find a non-empty 'device' attribute");

    llvm::SmallVector<mlir::Type, 4> result_types(
        op.getNumResults(), rewriter.getType<tfrt::fallback::TFTensorType>());

    auto ready_chain = rewriter.create<tfrt::compiler::NewChainOp>(
        op.getLoc(), rewriter.getType<tfrt::compiler::ChainType>());

    auto new_op = rewriter.create<tfrt::fallback_async::GetResourceOp>(
        op.getLoc(), corert_converter_.chain_type(), result_types, ready_chain,
        device.getValue(), op.indices());

    rewriter.replaceOp(op, new_op.results());

    return success();
  }

 private:
  CoreRTConverter &corert_converter_;
};

// Convert a tf_device.remote_run op to a tfrt_dist.remote_execute_func op.
//
// For example,
//
// %result = tf_device.remote_run "/job:worker/replica:0/task:0"
//   @remote_func(%arg0) : (tensor<i32>) -> (tensor<i32>)
//
// is converted to
//
// %0 = tfrt_dist.get_distributed_context
// %1 = tfrt_dist.get_task_handle %0
//  {task_name = "/job:worker/replica:0/task:0"}
// %2 = tfrt_dist.test_create_remote_chain_manager %0
// %3 = tfrt_dist.get_chain_for_task_handle %in_chain, %2, %1
// %out_op_chain, %results:2 = tfrt_dist.remote_execute_func[%in_chain, %0, %1]
//  @remote_func(%3, %arg1): (!tfrt_dist.remote_object_id, !corert.tensorhandle)
//  -> (!tfrt_dist.remote_object_id, !corert.tensorhandle)
// %4 = tfrt_dist.set_chain_for_task_handle %out_op_chain, %2, %1, %results#0
//
class TFDeviceRemoteRunOpConversion
    : public mlir::OpConversionPattern<tf_device::RemoteRunOp> {
 public:
  TFDeviceRemoteRunOpConversion(mlir::MLIRContext *context,
                                mlir::TypeConverter *type_converter,
                                CoreRTConverter *corert_converter)
      : mlir::OpConversionPattern<tf_device::RemoteRunOp>(context,
                                                          kCoreRTBenefit),
        type_converter_(*type_converter),
        corert_converter_(*corert_converter) {}

  LogicalResult matchAndRewrite(
      tf_device::RemoteRunOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    mlir::Value distributed_context =
        corert_converter_.GetDistributedContext(op.getOperation(), &rewriter);
    mlir::Value in_op_chain =
        corert_converter_.GetLocalSideEffectChain(op, &rewriter);
    mlir::Value task_handle = corert_converter_.GetTaskHandle(
        op.getOperation(), op.host(), &rewriter);
    mlir::Value remote_chain_mgr =
        corert_converter_.GetRemoteChainManager(op, &rewriter);
    mlir::Type remote_obj_id_ty =
        rewriter.getType<tfrt::dist::RemoteObjectIdType>();
    ModuleOp module = op->getParentOfType<ModuleOp>();
    SymbolTable symtab(module);
    func::FuncOp callee = symtab.lookup<func::FuncOp>(op.callee());
    if (!callee) {
      op.emitOpError("callee function ") << op.callee() << " is not found";
      return failure();
    }
    StringAttr host = callee->getAttrOfType<StringAttr>(kHostAttr);
    if (!host) {
      op.emitOpError("callee function ")
          << op.callee() << " should have the host attribute";
      return failure();
    }

    llvm::SmallVector<mlir::Value, 4> arguments;
    // The first argument of the remote function should be a remote chain which
    // is added to the function signature when it is lowered from TF dialect to
    // TFRT dialect.
    arguments.push_back(corert_converter_.GetRemoteSideEffectChain(
        op, host.getValue(), &rewriter));
    for (mlir::Value argument : op.callee_args()) {
      arguments.push_back(argument);
    }

    llvm::SmallVector<mlir::Type, 4> result_types;
    // The first result of the remote function should be a remote chain which
    // is added to the function signature when it is lowered from TF dialect to
    // TFRT dialect.
    result_types.push_back(remote_obj_id_ty);
    for (mlir::Type type : op.getResultTypes()) {
      (void)type_converter_.convertType(type, result_types);
    }
    auto remote_execute_func_op =
        rewriter.create<tfrt::dist::RemoteExecuteFuncOp>(
            op.getLoc(), corert_converter_.chain_type(), result_types,
            in_op_chain, distributed_context, task_handle, op.callee(),
            arguments);
    rewriter.replaceOp(op, remote_execute_func_op.results().drop_front(1));

    auto set_chain_op = rewriter.create<tfrt::dist::SetChainForTaskHandleOp>(
        op.getLoc(), corert_converter_.chain_type(),
        remote_execute_func_op.out_op_chain(), remote_chain_mgr, task_handle,
        remote_execute_func_op.results().front());
    corert_converter_.RegisterLocalSideEffectChain(op,
                                                   set_chain_op.out_op_chain());

    return success();
  }

 private:
  mlir::TypeConverter &type_converter_;
  CoreRTConverter &corert_converter_;
};

// Lowers a tf.BatchFunction to tfrt_fallback.batch_function.
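//
// For example (sketch; attribute details omitted and syntax shown roughly),
//
//   %0 = "tf.BatchFunction"(%arg) {f = @batched_fn, ...}
//
// is converted to, roughly,
//
//   %out_ch, %0 = tfrt_fallback.batch_function (%in_ch) @batched_fn (%arg) {...}
//
// where %in_ch/%out_ch thread the local side-effect chain.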
class FallbackBatchFunctionOpConversion
    : public mlir::OpConversionPattern<mlir::TF::BatchFunctionOp> {
 public:
  FallbackBatchFunctionOpConversion(mlir::MLIRContext *context,
                                    CoreRTConverter *corert_converter)
      : mlir::OpConversionPattern<mlir::TF::BatchFunctionOp>(context,
                                                             kFallbackBenefit),
        corert_converter_(*corert_converter) {}

  LogicalResult matchAndRewrite(
      mlir::TF::BatchFunctionOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    corert_converter_.MaterializeDerivedAttributes(op);

    // Remove the device attribute for fallback, as currently fallback will
    // select the device automatically.
    //
    // TODO(chky): The device attribute should be passed explicitly. This can
    // be done once we change the kernel implementation to choose the device
    // based on attributes.
    op->removeAttr(rewriter.getStringAttr(kDeviceAttr));

    SymbolRefAttr f = op.fAttr();

    llvm::SmallVector<NamedAttribute, 12> attr_array;
    for (auto &key_and_value : op->getAttrs()) {
      StringRef name = key_and_value.getName();
      if (name == "_output_shapes" || name == "f") {
        continue;
      }
      attr_array.push_back(key_and_value);
    }
    ArrayAttr op_attrs = corert_converter_.CreateOpAttrs(attr_array);
    if (!op_attrs) return op.emitWarning("failed to lower attributes.");

    llvm::SmallVector<Type, 4> result_types;
    for (auto type : op.getResultTypes()) {
      if (failed(corert_converter_.convertType(type, result_types)))
        return failure();
    }

    llvm::SmallVector<mlir::Value, 4> new_operands;
    if (mlir::failed(tfrt_compiler::ConvertCoreRTOperands(
            op, adaptor.getOperands(), &new_operands, rewriter)))
      return failure();

    auto new_op = rewriter.create<tfrt::fallback_async::BatchFunctionOp>(
        op.getLoc(), corert_converter_.chain_type(), result_types,
        corert_converter_.GetLocalSideEffectChain(op, &rewriter), new_operands,
        f, op_attrs);
    rewriter.replaceOp(op, new_op.results());

    // Register the converted op so that it can be retrieved by successors.
    corert_converter_.RegisterLocalSideEffectChain(op, new_op.out_op_chain());

    return success();
  }

 private:
  CoreRTConverter &corert_converter_;
};

// Lower a tf.Const op that creates a dense tensor of a supported numeric dtype
// to a native corert const dense tensor op.
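//
// For example (sketch; op syntax shown roughly),
//
//   %0 = "tf.Const"() {value = dense<1.0> : tensor<2xf32>} : () -> tensor<2xf32>
//
// is converted to, roughly,
//
//   %0 = corert.const_dense_tensor dense<1.0> : tensor<2xf32>
//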
class CoreRTConstDenseTensorOpConversion
    : public mlir::OpConversionPattern<mlir::TF::ConstOp> {
 public:
  CoreRTConstDenseTensorOpConversion(mlir::MLIRContext *context,
                                     CoreRTConverter *corert_converter)
      : mlir::OpConversionPattern<mlir::TF::ConstOp>(context, kCoreRTBenefit),
        corert_converter_(*corert_converter) {}

  LogicalResult matchAndRewrite(
      mlir::TF::ConstOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    if (!corert_converter_.IsSupportedNumericDType(op.dtype()))
      return failure();

    // Only CPU ops can be lowered using this conversion. If there is no device
    // assignment, this op is treated as a CPU op and can be lowered.
    if (auto parsed_device_name = corert_converter_.ParseDeviceName(op))
      if (parsed_device_name->device_type != DEVICE_CPU) return failure();

    auto new_op = rewriter.create<tfrt::corert::ConstDenseTensorOp>(
        op.getLoc(), corert_converter_.tensor_handle_type(),
        op.value().cast<DenseElementsAttr>());
    rewriter.replaceOp(op, new_op->getResult(0));
    return success();
  }

 private:
  CoreRTConverter &corert_converter_;
};

// Convert the FuncOp with the following changes to meet TFRT's requirements:
// 1) Convert types for the arguments and the results.
// 2) Add a chain to the arguments.
// 3) Add a chain to the results for side-effects.
// 4) If any argument has a tf.device attribute, change the attribute name
//    to tfrt.device.
// 5) If any result has a tf.device attribute, change the attribute name
//    to tfrt.device.
//
// The input chain is used to signal visibility of all side-effects before
// calling this function. The output chain is used to signal visibility of all
// side-effects of this function.
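//
// For example (types shown schematically),
//
//   func @f(%x: tensor<i32>) -> tensor<i32>
//
// is converted to, roughly,
//
//   func @f(%ch: !tfrt.chain, %x: <converted tensor type>)
//       -> (!tfrt.chain, <converted tensor type>)
//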
class TFRTFuncOpSignatureConversion
    : public mlir::OpConversionPattern<mlir::func::FuncOp> {
 public:
  TFRTFuncOpSignatureConversion(mlir::MLIRContext *ctx,
                                mlir::TypeConverter *type_converter)
      : OpConversionPattern(ctx), type_converter_(*type_converter) {}

  LogicalResult matchAndRewrite(
      mlir::func::FuncOp func_op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    mlir::FunctionType type = func_op.getFunctionType();

    // Convert the original function arguments.
    mlir::TypeConverter::SignatureConversion converted_signature(
        type.getNumInputs());
    // Add a chain as the first input.
    converted_signature.addInputs(
        rewriter.getType<tfrt::compiler::ChainType>());

    // Convert the original function results.
    SmallVector<Type, 1> converted_results;
    // Add a chain as the first result.
    converted_results.push_back(rewriter.getType<tfrt::compiler::ChainType>());

    if (failed(type_converter_.convertSignatureArgs(type.getInputs(),
                                                    converted_signature)) ||
        failed(type_converter_.convertTypes(type.getResults(),
                                            converted_results)) ||
        failed(rewriter.convertRegionTypes(&func_op.getBody(), type_converter_,
                                           &converted_signature))) {
      return failure();
    }

    llvm::SmallVector<mlir::DictionaryAttr, 4> arg_attrs;
    // The first input, which is a chain added by this pass, has no attribute.
    arg_attrs.emplace_back();
    func_op.getAllArgAttrs(arg_attrs);
    // If any argument has a tf.device attribute, change the attribute name to
    // tfrt.device.
    for (auto &arg_attr : arg_attrs) {
      mlir::NamedAttrList arg_attr_list(arg_attr);
      if (Attribute device = arg_attr_list.erase(kTFDeviceAttr)) {
        arg_attr_list.set(kTFRTDeviceAttr, device);
        arg_attr = arg_attr_list.getDictionary(device.getContext());
      }
    }
    arg_attrs.resize(converted_signature.getConvertedTypes().size());

    // The first result, which is a chain added by this pass, has no attribute.
    llvm::SmallVector<mlir::DictionaryAttr, 4> result_attrs;
    result_attrs.emplace_back();
    func_op.getAllResultAttrs(result_attrs);
    // If any result has a tf.device attribute, change the attribute name to
    // tfrt.device.
    for (auto &result_attr : result_attrs) {
      mlir::NamedAttrList result_attr_list(result_attr);
      if (Attribute device = result_attr_list.erase(kTFDeviceAttr)) {
        result_attr_list.set(kTFRTDeviceAttr, device);
        result_attr = result_attr_list.getDictionary(device.getContext());
      }
    }
    result_attrs.resize(converted_results.size());

    // Update the function signature in-place.
    rewriter.updateRootInPlace(func_op, [&] {
      func_op.setType(mlir::FunctionType::get(
          func_op.getContext(), converted_signature.getConvertedTypes(),
          converted_results));
      func_op.setAllArgAttrs(arg_attrs);
      func_op.setAllResultAttrs(result_attrs);
    });

    return success();
  }

 private:
  mlir::TypeConverter &type_converter_;
};

// Lower a tf.Const op that creates a string tensor to a native
// corert.create_string_tensor op.
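//
// For example (sketch),
//
//   %0 = "tf.Const"() {value = dense<["a", "b"]> : tensor<2x!tf_type.string>}
//
// is converted to, roughly, a corert const-string-tensor op that carries the
// shape [2] and the string values ["a", "b"] as array attributes.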
class CoreRTConstStringTensorOpConversion
    : public mlir::OpConversionPattern<mlir::TF::ConstOp> {
 public:
  CoreRTConstStringTensorOpConversion(mlir::MLIRContext *context,
                                      CoreRTConverter *corert_converter)
      : mlir::OpConversionPattern<mlir::TF::ConstOp>(context, kCoreRTBenefit),
        corert_converter_(*corert_converter) {}

  LogicalResult matchAndRewrite(
      mlir::TF::ConstOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {  // NOLINT
    if (!op.dtype().isa<mlir::TF::StringType>()) return failure();

    DenseStringElementsAttr attr = op.value().cast<DenseStringElementsAttr>();

    llvm::SmallVector<Attribute, 4> values;
    values.reserve(attr.getNumElements());
    for (const auto &element : attr.getRawStringData())
      values.push_back(rewriter.getStringAttr(
          llvm::StringRef(element.data(), element.size())));

    // Create the shape attribute from the tensor shape.
    ArrayRef<int64_t> shape = op.value().getType().getShape();
    llvm::SmallVector<mlir::Attribute, 4> dims;
    dims.reserve(shape.size());
    auto i64_type = rewriter.getIntegerType(64);
    for (auto dim : shape)
      dims.push_back(rewriter.getIntegerAttr(i64_type, dim));

    auto new_op = rewriter.create<tfrt::corert::ConstStringTensorOp>(
        op.getLoc(), corert_converter_.tensor_handle_type(),
        rewriter.getArrayAttr(dims), rewriter.getArrayAttr(values));

    rewriter.replaceOp(op, new_op.result());

    return success();
  }

 private:
  CoreRTConverter &corert_converter_;
};

// Convert TF dialect operations with no side effects to CoreRT ExecuteOp. For
// example,
//
// %0 = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} :
//    (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
//
// is converted to
//
// %result = corert.executeop(%device)
//    "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} :
//    (!corert.tensorhandle, !corert.tensorhandle) -> !corert.tensorhandle
//
// Note that it will fail to match if some attributes are not supported.
template <typename TF_Op>
class CoreRTExecuteOpConversion : public mlir::OpConversionPattern<TF_Op> {
 public:
  CoreRTExecuteOpConversion(mlir::MLIRContext *context,
                            CoreRTConverter *corert_converter)
      : CoreRTExecuteOpConversion(context, corert_converter, "") {}

  // If device_name is not empty, only ops that are placed on this device are
  // lowered using CoreRTExecuteOpConversion.
  CoreRTExecuteOpConversion(mlir::MLIRContext *context,
                            CoreRTConverter *corert_converter,
                            llvm::StringRef device_name)
      : mlir::OpConversionPattern<TF_Op>(context, kCoreRTBenefit),
        corert_converter_(*corert_converter),
        device_name_(device_name) {}

  LogicalResult matchAndRewrite(
      TF_Op op, typename TF_Op::Adaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto parsed_device_name = corert_converter_.ParseDeviceName(op);
    // Return failure and emit warning if there is no device assignment.
    if (!parsed_device_name) {
      return op->emitWarning(
          "failed to retrieve valid device when converting to "
          "corert.executeop");
    }

    // If device_name is specified, check the device of this op first.
    if (!device_name_.empty()) {
      // Skip if it does not match the specified device.
      if (parsed_device_name->device_name != device_name_) return failure();
    }

    mlir::StringAttr op_name = rewriter.getStringAttr(op.getOperationName());

    llvm::SmallVector<Type, 4> result_types;
    for (auto type : op.getOperation()->getResultTypes()) {
      if (failed(corert_converter_.convertType(type, result_types)))
        return failure();
    }

    corert_converter_.MaterializeDerivedAttributes(op);

    // Convert the function (symbol) attributes to an array of string
    // attributes, which represents the function names.
    llvm::SmallVector<mlir::StringAttr, 4> func_attr_keys;
    ArrayAttr op_func_attrs =
        corert_converter_.CreateOpFuncAttrs(op->getAttrs(), &func_attr_keys);

    // Remove the function attributes, which have already been processed.
    for (const auto &key : func_attr_keys) op->removeAttr(key);

    ArrayAttr op_attrs = corert_converter_.CreateOpAttrs(op->getAttrs());
    if (!op_attrs) return op.emitError("failed to lower attributes.");

    llvm::SmallVector<mlir::Value, 4> new_operands;
    if (mlir::failed(tfrt_compiler::ConvertCoreRTOperands(
            op, adaptor.getOperands(), &new_operands, rewriter)))
      return failure();

    // Get the op handler, or create one if it does not exist. Note that
    // ConvertOpHandler changes internal state, so it can only be called if the
    // rewrite is guaranteed to succeed afterwards.
    auto op_handler = corert_converter_.ConvertOpHandler(
        op, parsed_device_name->op_handler_name, &rewriter);
    if (!op_handler) return failure();

    auto new_op = rewriter.create<tfrt::corert::ExecuteOp>(
        op.getLoc(), result_types, op_handler, new_operands, op_attrs,
        op_func_attrs, op_name);

    rewriter.replaceOp(op, new_op.results());
    return success();
  }

 private:
  CoreRTConverter &corert_converter_;
  llvm::StringRef device_name_;
};

LogicalResult ConvertFunctionCallOperands(
    mlir::Operation *op, ValueRange operands,
    llvm::SmallVectorImpl<mlir::Value> *new_operands,
    mlir::ConversionPatternRewriter &rewriter, bool func_use_fallback_tensor) {
  if (func_use_fallback_tensor) {
    // TODO(b/182232457): Support other devices.
    return tfrt_compiler::ConvertFallbackOperands(
        op, tfrt_compiler::GetDefaultCpuDeviceName(), operands, new_operands,
        rewriter);
  } else {
    return tfrt_compiler::ConvertCoreRTOperands(op, operands, new_operands,
                                                rewriter);
  }
}

// Convert TF call ops (eg. StatefulPartitionedCall) to tfrt.call.
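//
// For example (sketch; types omitted),
//
//   %0 = "tf.StatefulPartitionedCall"(%arg) {f = @callee, ...}
//
// is converted to, roughly,
//
//   %out_ch, %0 = tfrt.call @callee(%in_ch, %arg)
//
// where the chain threads side-effect ordering through the callee.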
template <typename CallOp>
class TFRTCallOpConversion : public mlir::OpConversionPattern<CallOp> {
 public:
  TFRTCallOpConversion(mlir::MLIRContext *context,
                       mlir::TypeConverter *type_converter,
                       CoreRTConverter *corert_converter,
                       bool func_use_fallback_tensor)
      : mlir::OpConversionPattern<CallOp>(context),
        type_converter_(*type_converter),
        corert_converter_(*corert_converter),
        func_use_fallback_tensor_(func_use_fallback_tensor) {}

  LogicalResult matchAndRewrite(
      CallOp op, typename CallOp::Adaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto callee =
        op.getCallableForCallee().template dyn_cast<mlir::SymbolRefAttr>();
    if (!callee) return failure();

    llvm::SmallVector<mlir::Value, 4> new_operands;
    new_operands.push_back(
        corert_converter_.GetLocalSideEffectChain(op, &rewriter));
    // Currently the converted functions always use !corert.tensorhandle types
    // for tensor arguments and results.
    //
    // TODO(b/175881042): We should avoid the tensor conversion here if the
    // operand is !tfrt_fallback.tf_tensor, and it is also used as fallback
    // tensor inside the callee function.
    if (mlir::failed(ConvertFunctionCallOperands(op, adaptor.getOperands(),
                                                 &new_operands, rewriter,
                                                 func_use_fallback_tensor_)))
      return failure();

    llvm::SmallVector<mlir::Type, 4> result_types;
    result_types.push_back(rewriter.getType<tfrt::compiler::ChainType>());
    for (auto type : op.getOperation()->getResultTypes()) {
      if (failed(type_converter_.convertType(type, result_types)))
        return failure();
    }

    auto new_op = rewriter.create<tfrt::compiler::CallOp>(
        op.getLoc(), result_types, callee.getRootReference().getValue(),
        new_operands);
    rewriter.replaceOp(op, new_op.getResults().drop_front());

    if (!mlir::MemoryEffectOpInterface::hasNoEffect(op)) {
      // Register the converted op so that it can be retrieved by successors.
      // TODO(chky): Add OpTraits or OpInterface, rather than assume first
      // result is a chain.
      corert_converter_.RegisterLocalSideEffectChain(op, new_op.getResult(0));
    }

    return success();
  }

 private:
  mlir::TypeConverter &type_converter_;
  CoreRTConverter &corert_converter_;
  bool func_use_fallback_tensor_;
};

// Convert func ReturnOp to tfrt.return.
//
// TODO(chky): conversion to tfrt kernels should come from a common tf_to_tfrt
// library.
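//
// For example (sketch; types omitted),
//
//   func.return %0 : tensor<i32>
//
// is converted to, roughly,
//
//   tfrt.return %chain, %0
//
// where %chain is the local side-effect chain of the function.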
class TFRTReturnOpConversion
    : public mlir::OpConversionPattern<mlir::func::ReturnOp> {
 public:
  TFRTReturnOpConversion(mlir::MLIRContext *context,
                         CoreRTConverter *corert_converter,
                         bool func_use_fallback_tensor)
      : mlir::OpConversionPattern<mlir::func::ReturnOp>(context),
        corert_converter_(*corert_converter),
        func_use_fallback_tensor_(func_use_fallback_tensor) {}

  LogicalResult matchAndRewrite(
      mlir::func::ReturnOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<mlir::Value, 2> new_operands;

    // Currently in mlir::TF::SideEffectAnalysis, all terminator ops are treated
    // as side-effect ops and they have predecessors but not successors.
    //
    // TODO(chky): ReturnOp has no side effect. Once the special handling in
    // mlir::TF::SideEffectAnalysis is removed, the chains should come from
    // side-effecting ops with no successors in the function.
    new_operands.push_back(
        corert_converter_.GetLocalSideEffectChain(op, &rewriter));
    if (mlir::failed(ConvertFunctionCallOperands(op, adaptor.getOperands(),
                                                 &new_operands, rewriter,
                                                 func_use_fallback_tensor_)))
      return failure();

    rewriter.replaceOpWithNewOp<tfrt::compiler::ReturnOp>(op, new_operands);
    return success();
  }

 private:
  CoreRTConverter &corert_converter_;
  bool func_use_fallback_tensor_;
};

// Convert tf.Case op to tfrt.Case.
//
// TF dialect:
// %outputs = "tf.Case"(%arg, ...) { branches = [@branch0, @branch1], ...}
//
// lowered TFRT CoreRT dialect:
// %idx_int = corert.tensorhandle_to_int32 %idx
// %out_chain, %outputs = tfrt.case %idx_int [@branch0, @branch1] (%chain, ...)
class TFRTCaseOpConversion : public mlir::OpConversionPattern<TF::CaseOp> {
 public:
  TFRTCaseOpConversion(mlir::MLIRContext *context,
                       mlir::TypeConverter *type_converter,
                       CoreRTConverter *corert_converter,
                       bool func_use_fallback_tensor)
      : mlir::OpConversionPattern<TF::CaseOp>(context),
        type_converter_(*type_converter),
        corert_converter_(*corert_converter),
        func_use_fallback_tensor_(func_use_fallback_tensor) {}

  LogicalResult matchAndRewrite(
      TF::CaseOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    mlir::ArrayAttr branches = op.branches();

    llvm::SmallVector<mlir::Type, 4> result_types;
    result_types.push_back(corert_converter_.chain_type());
    for (mlir::Type type : op->getResultTypes()) {
      if (failed(type_converter_.convertType(type, result_types)))
        return failure();
    }

    llvm::SmallVector<mlir::Value, 4> branch_operands;
    branch_operands.push_back(
        corert_converter_.GetLocalSideEffectChain(op, &rewriter));
    if (mlir::failed(ConvertFunctionCallOperands(
            op, adaptor.getOperands().drop_front(), &branch_operands, rewriter,
            func_use_fallback_tensor_)))
      return failure();

    mlir::Value index_operand = adaptor.getOperands()[0];
    // TODO(b/182233401): Support TF tensor; remove the conversion op here.
    if (index_operand.getType().isa<tfrt::fallback::TFTensorType>()) {
      // TODO(b/182232457): Support other devices.
      index_operand =
          rewriter
              .create<
                  tfrt::fallback_async::FallbackTensorToCoreRTTensorHandleOp>(
                  op.getLoc(),
                  rewriter.getType<tfrt::corert::TensorHandleType>(),
                  adaptor.getOperands()[0],
                  tfrt_compiler::GetDefaultCpuDeviceName())
              .getResult(0);
    }
    if (!index_operand.getType().isa<tfrt::corert::TensorHandleType>())
      return op.emitError(
          "branch index operand is expected to be a TensorHandle.");
    mlir::Value index_value =
        rewriter.create<tfrt::corert::TensorHandleToIntOp>(
            op.getLoc(), rewriter.getI32Type(), index_operand);

    auto new_op = rewriter.create<tfrt::compiler::CaseOp>(
        op.getLoc(), result_types, index_value, branches, branch_operands);

    rewriter.replaceOp(op, new_op.branch_outputs().drop_front());
    return success();
  }

 private:
  mlir::TypeConverter &type_converter_;
  CoreRTConverter &corert_converter_;
  bool func_use_fallback_tensor_;
};

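// Converts `cond_operand` to a boolean (i1) value usable by tfrt control flow
// kernels: the operand is first converted to a fallback tensor if needed, and
// then fed through tfrt_fallback_async.predicate. Returns a null value on
// failure.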
static mlir::Value GetPredicate(mlir::Operation *op, mlir::Value cond_operand,
                                mlir::ConversionPatternRewriter &rewriter) {
  if (!cond_operand.getType().isa<tfrt::fallback::TFTensorType>()) {
    cond_operand = tfrt_compiler::ConvertCoreRTTensorHandleToFallbackTensor(
        op->getLoc(), tfrt_compiler::GetDefaultCpuDeviceName(), cond_operand,
        rewriter);
    if (!cond_operand) {
      op->emitWarning("failed to convert the cond operand to fallback tensor.");
      return {};
    }
  }

  return rewriter.create<tfrt::fallback_async::PredicateOp>(
      op->getLoc(), rewriter.getI1Type(), cond_operand);
}

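// Convert tf.If to tfrt.cond. For example (sketch; syntax shown roughly),
//
//   %0 = "tf.If"(%cond, %arg) {then_branch = @then, else_branch = @else, ...}
//
// is converted to, roughly,
//
//   %cond_bool = tfrt_fallback_async.predicate %cond_tensor
//   %out_ch, %0 = tfrt.cond %cond_bool @then @else (%in_ch, %arg)
//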
class TFRTCondOpConversion : public mlir::OpConversionPattern<mlir::TF::IfOp> {
 public:
  TFRTCondOpConversion(mlir::MLIRContext *context,
                       mlir::TypeConverter *type_converter,
                       CoreRTConverter *corert_converter,
                       bool func_use_fallback_tensor)
      : mlir::OpConversionPattern<TF::IfOp>(context),
        type_converter_(*type_converter),
        corert_converter_(*corert_converter),
        func_use_fallback_tensor_(func_use_fallback_tensor) {}

  mlir::LogicalResult matchAndRewrite(
      mlir::TF::IfOp op, OpAdaptor adaptor,
      mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::FlatSymbolRefAttr then_branch = op.then_branchAttr();
    mlir::FlatSymbolRefAttr else_branch = op.else_branchAttr();

    llvm::SmallVector<Type, 4> result_types;
    result_types.push_back(rewriter.getType<tfrt::compiler::ChainType>());
    for (mlir::Type type : op.getOperation()->getResultTypes()) {
      if (failed(type_converter_.convertType(type, result_types)))
        return failure();
    }

    // Convert the cond tensor to a boolean value so that it can be used by
    // tfrt.cond kernel.
    auto bool_cond = GetPredicate(op, adaptor.getOperands()[0], rewriter);
    if (!bool_cond) return failure();

    llvm::SmallVector<mlir::Value, 4> new_operands;
    // The first arg of converted branch functions should be !tfrt.chain.
    new_operands.push_back(
        corert_converter_.GetLocalSideEffectChain(op, &rewriter));

    if (mlir::failed(ConvertFunctionCallOperands(
            op, adaptor.getOperands().drop_front(), &new_operands, rewriter,
            func_use_fallback_tensor_)))
      return failure();

    auto new_op = rewriter.create<tfrt::compiler::CondOp>(
        op.getLoc(), result_types, bool_cond, then_branch, else_branch,
        new_operands);

    // The first result is a !tfrt.chain.
    rewriter.replaceOp(op, new_op.getResults().drop_front(1));

    if (!mlir::MemoryEffectOpInterface::hasNoEffect(op)) {
      // Register the converted op so that it can be retrieved by successors.
      // TODO(chky): Add OpTraits or OpInterface, rather than assume first
      // result is a chain.
      corert_converter_.RegisterLocalSideEffectChain(op, new_op.getResult(0));
    }

    return success();
  }

 private:
  mlir::TypeConverter &type_converter_;
  CoreRTConverter &corert_converter_;
  bool func_use_fallback_tensor_;
};

// Convert TF WhileOp to tfrt.while. tfrt.while uses a boolean condition and has
// slightly different semantics from tf.While for performance and generality.
// The pseudo code of tfrt.while is as follows:
//
//   while(cond) {
//     outputs, cond = body(inputs)
//     inputs = outputs
//   }
//   return outputs
//
// So we need to insert extra conversion kernels and merge functions when
// lowering tf.While to tfrt.while.
//
// %result = tf.While(%arg) {cond = @original_cond_fn, body =
// @original_body_fn}
//
// is converted to
//
// func @new_pred_fn(%ch, %arg) {
//  %ch0, %cond_tensor = tfrt.call @original_cond_fn(%ch, %arg)
//  %cond_bool = tfrt_fallback_async.predicate %cond_tensor
//  tfrt.return %ch0, %cond_bool
// }
//
// func @new_while_body(%ch, %arg) {
//  %ch0, %result = tfrt.call @original_body_fn(%ch, %arg)
//  %ch1, %cond_bool = tfrt.call @new_pred_fn(%ch0, %result)
//  tfrt.return %ch1, %result, %cond_bool
// }
//
// %ch0, %first_iter_cond = tfrt.call @new_pred_fn(%ch, %arg)
// %ch1, %result = tfrt.while %first_iter_cond @new_while_body(%ch0, %arg)
//
class TFRTWhileOpConversion
    : public mlir::OpConversionPattern<mlir::TF::WhileOp> {
 public:
  TFRTWhileOpConversion(mlir::MLIRContext *context,
                        mlir::TypeConverter *type_converter,
                        CoreRTConverter *corert_converter,
                        mlir::SymbolTable *symbol_table,
                        const tfrt_compiler::TensorArraySideEffectAnalysis
                            *tensor_array_side_effect_analysis,
                        bool func_use_fallback_tensor,
                        bool enable_while_parallel_iterations)
      : mlir::OpConversionPattern<TF::WhileOp>(context),
        type_converter_(*type_converter),
        corert_converter_(*corert_converter),
        symbol_table_(*symbol_table),
        tensor_array_side_effect_analysis_(*tensor_array_side_effect_analysis),
        func_use_fallback_tensor_(func_use_fallback_tensor),
        enable_while_parallel_iterations_(enable_while_parallel_iterations) {}

  mlir::LogicalResult matchAndRewrite(
      mlir::TF::WhileOp op, OpAdaptor adaptor,
      mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::FlatSymbolRefAttr cond_fn = op.condAttr();
    mlir::FlatSymbolRefAttr body_fn = op.bodyAttr();

    llvm::SmallVector<Type, 4> while_arg_result_types;
    // Insert a chain for side effects as the first argument/result.
    while_arg_result_types.push_back(
        rewriter.getType<tfrt::compiler::ChainType>());
    for (mlir::Type type : op.getOperation()->getResultTypes()) {
      if (failed(type_converter_.convertType(type, while_arg_result_types)))
        return failure();
    }

    // Convert the operands to either fallback tensor or corert tensors as
    // specified in the option.
    llvm::SmallVector<mlir::Value, 4> new_operands;
    if (mlir::failed(ConvertFunctionCallOperands(op, adaptor.getOperands(),
                                                 &new_operands, rewriter,
                                                 func_use_fallback_tensor_)))
      return failure();

    // Create the predicate function that calls the original cond function and
    // in addition convert the result to a boolean value.
    mlir::func::FuncOp pred_fn =
        GetPredicateFunction(op, cond_fn, while_arg_result_types, rewriter);
    if (!pred_fn) return failure();

    auto in_chain = corert_converter_.GetLocalSideEffectChain(op, &rewriter);
    auto out_chain = in_chain;

    bool has_at_most_tensor_array_effect = HasAtMostTensorArrayEffect(op);

    // Prepare the arguments to call the pred function for the first iteration.
    llvm::SmallVector<mlir::Value, 4> pred_args;
    pred_args.push_back(
        has_at_most_tensor_array_effect ? GetFunctionInputChain(op) : in_chain);
    pred_args.append(new_operands.begin(), new_operands.end());

    // Insert a call op to call the pred function for the first iteration.
    auto call_pred_fn = rewriter.create<tfrt::compiler::CallOp>(
        op.getLoc(), pred_fn.getFunctionType().getResults(),
        pred_fn.getSymName(), pred_args);

    auto pred_chain = call_pred_fn.getResult(0);
    auto first_iteration_bool_cond = call_pred_fn.getResult(1);

    mlir::func::FuncOp new_body_fn = GetWhileBodyFunction(
        op, body_fn, pred_fn, while_arg_result_types, rewriter);

    // Use the pred chain as the chain to the while body. The remaining args
    // should be the same as pred_args.
1250 auto &while_args = pred_args;
1251 while_args[0] = pred_chain;
1252
1253 int64_t parallel_iterations =
1254 enable_while_parallel_iterations_ ? op.parallel_iterations() : 1;
1255
1256 auto new_op = rewriter.create<tfrt::compiler::WhileOp>(
1257 op.getLoc(), while_arg_result_types, first_iteration_bool_cond,
1258 while_args, new_body_fn.getSymName(), parallel_iterations);
1259
1260 rewriter.replaceOp(op, new_op.getResults().drop_front());
1261
1262 if (!has_at_most_tensor_array_effect) out_chain = new_op.getResult(0);
1263
1264 if (!mlir::MemoryEffectOpInterface::hasNoEffect(op)) {
1265 // Register the converted op so that it can be retrieved by successors.
1266 // TODO(chky): Add OpTraits or OpInterface, rather than assume first
1267 // result is a chain.
1268 corert_converter_.RegisterLocalSideEffectChain(op, out_chain);
1269 }
1270 return success();
1271 }
1272
1273 private:
1274   bool HasAtMostTensorArrayEffect(mlir::TF::WhileOp op) const {
1275 return tensor_array_side_effect_analysis_.HasAtMostTensorArrayEffect(
1276 op.cond_function()) &&
1277 tensor_array_side_effect_analysis_.HasAtMostTensorArrayEffect(
1278 op.body_function());
1279 }
1280
1281 mlir::func::FuncOp GetPredicateFunction(
1282 mlir::TF::WhileOp op, mlir::FlatSymbolRefAttr cond_fn,
1283 mlir::TypeRange arg_types,
1284 mlir::ConversionPatternRewriter &rewriter) const;
1285
1286 mlir::func::FuncOp GetWhileBodyFunction(
1287 mlir::TF::WhileOp op, mlir::FlatSymbolRefAttr body_fn,
1288 mlir::func::FuncOp pred_fn, mlir::TypeRange arg_types,
1289 mlir::ConversionPatternRewriter &rewriter) const;
1290
1291 mlir::TypeConverter &type_converter_;
1292 CoreRTConverter &corert_converter_;
1293 mlir::SymbolTable &symbol_table_;
1294 const tfrt_compiler::TensorArraySideEffectAnalysis
1295 &tensor_array_side_effect_analysis_;
1296 bool func_use_fallback_tensor_;
1297 bool enable_while_parallel_iterations_;
1298 };
1299
1300 // Create the pred function that contains a call to the original cond function
1301 // and a predicate kernel that converts the cond tensor to a boolean value, e.g.:
1302 //
1303 // func @pred_fn(%ch, %arg) {
1304 // %ch0, %cond_tensor = tfrt.call @original_cond_fn(%ch, %arg)
1305 // %cond_bool = tfrt_fallback_async.predicate %cond_tensor
1306 //   tfrt.return %ch0, %cond_bool
1307 // }
1308 //
1309 mlir::func::FuncOp TFRTWhileOpConversion::GetPredicateFunction(
1310 mlir::TF::WhileOp op, mlir::FlatSymbolRefAttr cond_fn,
1311 mlir::TypeRange arg_types,
1312 mlir::ConversionPatternRewriter &rewriter) const {
1313 std::string pred_fn_name = cond_fn.getValue().str() + "/tfrt_predicate";
1314
1315 if (auto pred_fn = symbol_table_.lookup<mlir::func::FuncOp>(pred_fn_name)) {
1316 return pred_fn;
1317 }
1318
1319 auto func_op = op->getParentOfType<mlir::func::FuncOp>();
1320
1321 mlir::ConversionPatternRewriter::InsertionGuard insertion_guard(rewriter);
1322 rewriter.setInsertionPointAfter(func_op);
1323
1324 std::array<mlir::Type, 2> pred_result_types = {
1325 rewriter.getType<tfrt::compiler::ChainType>(), rewriter.getI1Type()};
1326
1327 auto func_type = rewriter.getFunctionType(arg_types, pred_result_types);
1328
1329 auto pred_fn =
1330 rewriter.create<mlir::func::FuncOp>(op.getLoc(), pred_fn_name, func_type);
1331
1332 auto *block = pred_fn.addEntryBlock();
1333 rewriter.setInsertionPointToStart(block);
1334
1335 // There are at least two arguments, with the first being !tfrt.chain and the
1336 // second being either !tfrt_fallback.tf_tensor or !corert.tensorhandle.
1337 // cond_fn must have two results. The first of them must also be !tfrt.chain
1338 // and the second must have the same tensor type as arguments. So we can just
1339 // use the first two arg types as the result types.
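  // For illustration, after lowering the original cond function is expected to
  // look roughly like this (a sketch; <tensor> stands for either
  // !tfrt_fallback.tf_tensor or !corert.tensorhandle):
  //
  //   func @original_cond(%ch: !tfrt.chain, %x: <tensor>, ...)
  //       -> (!tfrt.chain, <tensor>)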
1340 assert(arg_types.size() >= 2);
1341 auto cond_result_types = arg_types.take_front(2);
1342
1343 auto call_cond_fn = rewriter.create<tfrt::compiler::CallOp>(
1344 op.getLoc(), cond_result_types, cond_fn, block->getArguments());
1345
1346 auto chain = call_cond_fn.getResult(0);
1347 auto cond = call_cond_fn.getResult(1);
1348
1349 auto bool_cond = GetPredicate(op, cond, rewriter);
1350 if (!bool_cond) return {};
1351
1352 llvm::SmallVector<mlir::Value, 2> results = {chain, bool_cond};
1353
1354 rewriter.create<tfrt::compiler::ReturnOp>(op.getLoc(), results);
1355
1356 symbol_table_.insert(pred_fn);
1357
1358 return pred_fn;
1359 }
1360
1361 // Create the new while body function that contains a call to the original
1362 // while body and then a call to the pred function, e.g.:
1363 //
1364 // func @new_while_body(%ch, %arg) {
1365 // %ch0, %result = tfrt.call @original_body(%ch, %arg)
1366 // %ch1, %cond_bool = tfrt.call @pred_function(%ch0, %arg)
1367 // tfrt.return %ch1, %result, %cond_bool
1368 // }
1369 //
1370 mlir::func::FuncOp TFRTWhileOpConversion::GetWhileBodyFunction(
1371 mlir::TF::WhileOp op, mlir::FlatSymbolRefAttr original_body_fn,
1372 mlir::func::FuncOp pred_fn, mlir::TypeRange arg_types,
1373 mlir::ConversionPatternRewriter &rewriter) const {
1374 int64_t parallel_iterations =
1375 enable_while_parallel_iterations_ ? op.parallel_iterations() : 1;
1376
1377 std::string body_fn_name = original_body_fn.getValue().str() + "/tfrt_body_" +
1378 absl::StrCat(parallel_iterations);
1379
1380 if (auto body_fn = symbol_table_.lookup<mlir::func::FuncOp>(body_fn_name)) {
1381 return body_fn;
1382 }
1383
1384 auto func_op = op->getParentOfType<mlir::func::FuncOp>();
1385
1386 mlir::ConversionPatternRewriter::InsertionGuard insertion_guard(rewriter);
1387 rewriter.setInsertionPointAfter(func_op);
1388
1389 auto while_result_types = arg_types;
1390
1391 llvm::SmallVector<mlir::Type, 4> body_result_types(arg_types.begin(),
1392 arg_types.end());
1393
1394 // The last result of the while body function is the boolean condition.
1395 body_result_types.push_back(rewriter.getI1Type());
1396 auto func_type = rewriter.getFunctionType(arg_types, body_result_types);
1397 auto body_fn =
1398 rewriter.create<mlir::func::FuncOp>(op.getLoc(), body_fn_name, func_type);
1399 if (parallel_iterations > 1) {
1400     // Disable stream merging by setting the cost threshold to 1. The key to
1401     // parallelizing while iterations is to execute iteration index handling
1402     // (e.g. ++i; i < max_iterations) in parallel with the loop bodies. Since
1403     // iteration index handling is usually cheap when this while loop is
1404     // parallelizable, we don't want to merge it with loop bodies into one
1405     // stream of inline execution. The quickest way to achieve this is to
1406     // disable stream merging for loop functions. The potential downside is
1407     // that stream merging won't be applied to other parts of the loop body,
1408     // so there might be excessive threading overhead.
1409     //
1410     // TODO(chky): Consider a better way of parallelizing while ops, so that
1411     // we can have stream merging applied within the loop body. One option is
1412     // to perform a compiler transformation that extracts the iteration index
1413     // handling logic out of the loop body and converts it to a parallel_map-like op.
1414 body_fn->setAttr("tfrt.cost_threshold", rewriter.getI64IntegerAttr(1));
1415 }
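  // When parallel iterations are enabled, the generated body function thus
  // carries a function-level attribute that effectively opts it out of stream
  // merging, roughly (a sketch; "@my_body/tfrt_body_16" is an invented name
  // and the printed form may differ):
  //
  //   func @my_body/tfrt_body_16(...) attributes {tfrt.cost_threshold = 1 : i64}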
1416
1417 auto *block = body_fn.addEntryBlock();
1418 rewriter.setInsertionPointToStart(block);
1419
1420 // Insert a call to the original body function.
1421 auto call_original_body_fn = rewriter.create<tfrt::compiler::CallOp>(
1422 op.getLoc(), while_result_types, original_body_fn, block->getArguments());
1423
1424   // Insert a call to the pred function, which contains a call to the original
1425   // cond function and the predicate kernel that converts the tensor to a
1426   // boolean value.
1427 auto call_pred_fn = rewriter.create<tfrt::compiler::CallOp>(
1428 op.getLoc(), pred_fn.getFunctionType().getResults(), pred_fn.getSymName(),
1429 call_original_body_fn.getResults());
1430
1431 auto pred_chain = call_pred_fn.getResult(0);
1432
1433 llvm::SmallVector<mlir::Value, 4> body_results;
1434 // The first result should be the chain returned from the pred function.
1435 body_results.push_back(pred_chain);
1436   // Then come the results from the original while body, excluding the first
1437   // chain (which is replaced by `pred_chain`).
1438 for (auto res : llvm::drop_begin(call_original_body_fn.getResults())) {
1439 body_results.push_back(res);
1440 }
1441 // The last result should be the boolean value converted from the condition.
1442 auto bool_cond = call_pred_fn.getResult(1);
1443 body_results.push_back(bool_cond);
1444
1445 rewriter.create<tfrt::compiler::ReturnOp>(op.getLoc(), body_results);
1446
1447 symbol_table_.insert(body_fn);
1448
1449 return body_fn;
1450 }
1451
1452 // TODO(ezhulenev): tf_device.cluster operations after auto-fusion should
1453 // have the correct device assigned based on the fused operations. We should
1454 // use this device to convert operands and results from/to corert handles.
1455 // For now it is safe to assume that it is "CPU" because we do not support
1456 // any other devices and do not support distributed models.
1457 constexpr char kJitRtDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0";
1458
1459 // Convert jitrt.call operations to the tf_jitrt.fallback.execute operation.
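// For example, a call into an outlined JIT cluster such as (an illustrative
// sketch only; the exact textual forms of jitrt.call and
// tf_jitrt.fallback.execute may differ):
//
//   %0 = jitrt.call @kernel::@compute(%arg0) : (...) -> ...
//
// is rewritten to execute the same compiled function through the fallback,
// with all operands and results converted to !tfrt_fallback.tf_tensor and the
// kernel pinned to the kJitRtDevice defined above.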
1460 class JitRtCallToJitRtCompileAndExecuteConversion
1461 : public OpConversionPattern<tfrt::jitrt::CallOp> {
1462 public:
1463   explicit JitRtCallToJitRtCompileAndExecuteConversion(MLIRContext *context)
1464 : OpConversionPattern<tfrt::jitrt::CallOp>(context) {}
1465
1466   LogicalResult matchAndRewrite(
1467 tfrt::jitrt::CallOp call, OpAdaptor adaptor,
1468 ConversionPatternRewriter &rewriter) const override {
1469 // Convert operands to fallback tensors.
1470 llvm::SmallVector<Value, 4> fallback_operands;
1471 if (failed(tfrt_compiler::ConvertFallbackOperands(
1472 call, kJitRtDevice, adaptor.getOperands(), &fallback_operands,
1473 rewriter)))
1474 return rewriter.notifyMatchFailure(call, "failed to convert operand");
1475
1476 // tf_jitrt.fallback.execute always produces fallback tensors.
1477 llvm::SmallVector<Type, 4> result_types(
1478 call->getNumResults(),
1479 rewriter.getType<tfrt::fallback::TFTensorType>());
1480
1481 // Replace jitrt.call operation with a tf_jitrt.fallback.execute operation.
1482 rewriter.replaceOpWithNewOp<tf_jitrt::FallbackExecuteOp>(
1483 call, result_types, call.callee(), fallback_operands, kJitRtDevice);
1484
1485 return success();
1486 }
1487 };
1488
1489 // Helper function for specifying legal dialects for conversion to CoreRT.
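// A func.func op is treated as dynamically legal only once its signature
// already threads the side-effect chain, i.e. roughly (a sketch; <tensor> is
// either !tfrt_fallback.tf_tensor or !corert.tensorhandle depending on the
// configured func_type_converter):
//
//   func @fn(%ch: !tfrt.chain, %t: <tensor>) -> (!tfrt.chain, <tensor>)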
1490 void SetUpTFToTFRTConversionLegality(mlir::ConversionTarget *target,
1491 mlir::TypeConverter *func_type_converter,
1492 mlir::Type chain_type) {
1493 target->addLegalDialect<tfrt::corert::CoreRTDialect>();
1494 target->addLegalDialect<tfrt::fallback_async::FallbackAsyncDialect>();
1495 target->addLegalDialect<tfrt::compiler::TFRTDialect>();
1496 target->addLegalDialect<tfrt::dist::DistributedDialect>();
1497 target->addLegalDialect<tfrt::test::TestDialect>();
1498 target->addLegalDialect<tf_jitrt::JitRuntimeDialect>();
1499 target->addIllegalDialect<TF::TensorFlowDialect>();
1500 target->addIllegalDialect<tf_device::TensorFlowDeviceDialect>();
1501 target->addIllegalDialect<tfrt::jitrt::JitRuntimeDialect>();
1502 target->addDynamicallyLegalOp<mlir::func::FuncOp>([func_type_converter,
1503 chain_type](
1504 func::FuncOp op) {
1505 auto func_type = op.getFunctionType();
1506 if (func_type.getNumInputs() == 0 || func_type.getInput(0) != chain_type)
1507 return false;
1508 if (func_type.getNumResults() == 0 || func_type.getResult(0) != chain_type)
1509 return false;
1510
1511 return func_type_converter->isSignatureLegal(op.getFunctionType()) &&
1512 func_type_converter->isLegal(&op.getBody());
1513 });
1514 }
1515
1516 // Helper function for inserting TFRT JitRt dialect conversions.
1517 void PopulateJitRtConversionPatterns(MLIRContext *context,
1518 RewritePatternSet *patterns,
1519 CoreRTConverter *corert_converter) {
1520 // Lower jitrt.call to the pair of compile and execute operations.
1521 patterns->add<JitRtCallToJitRtCompileAndExecuteConversion>(context);
1522 }
1523
1524 // Helper function for inserting TF dialect to TFRT dialect op conversion
1525 // patterns.
1526 void PopulateTFToTFRTConversionPatterns(
1527 mlir::MLIRContext *context, mlir::RewritePatternSet *patterns,
1528 CoreRTConverter *corert_converter,
1529 tfrt_compiler::FallbackConverter *fallback_converter,
1530 mlir::SymbolTable *symbol_table,
1531 const tfrt_compiler::CostAnalysis *cost_analysis,
1532 const tfrt_compiler::TensorArraySideEffectAnalysis
1533 *tensor_array_side_effect_analysis,
1534 bool enable_native_ops, bool func_use_fallback_tensor,
1535 bool enable_while_parallel_iterations, bool tpu_lower_to_fallback,
1536 bool target_tpurt) {
1537 // By default, we lower all TF ops to fallback ops.
1538 patterns->add<FallbackExecuteOpConversion>(
1539 context, corert_converter, fallback_converter, cost_analysis,
1540 tpu_lower_to_fallback, target_tpurt);
1541 patterns->add<FallbackConstOpConversion, FallbackSetResourceOp,
1542 FallbackGetResourceOp>(context, corert_converter);
1543
1544 // For control flow ops, we handle them according to the option.
1545 mlir::TypeConverter *func_type_converter;
1546 if (func_use_fallback_tensor) {
1547 func_type_converter = fallback_converter;
1548 } else {
1549 func_type_converter = corert_converter;
1550 }
1551 patterns->add<TFRTFuncOpSignatureConversion>(context, func_type_converter);
1552 patterns->add<TFRTReturnOpConversion>(context, corert_converter,
1553 func_use_fallback_tensor);
1554 patterns->add<TFRTWhileOpConversion>(
1555 context, func_type_converter, corert_converter, symbol_table,
1556 tensor_array_side_effect_analysis, func_use_fallback_tensor,
1557 enable_while_parallel_iterations);
1558 patterns->add<TFRTCallOpConversion<mlir::TF::StatefulPartitionedCallOp>,
1559 TFRTCallOpConversion<mlir::TF::PartitionedCallOp>,
1560 TFRTCallOpConversion<mlir::TF::LegacyCallOp>,
1561 TFRTCaseOpConversion, TFRTCondOpConversion>(
1562 context, func_type_converter, corert_converter, func_use_fallback_tensor);
1563
1564 // For tf.BatchFunction, we need a special fallback op to batch a BEF
1565 // function.
1566 patterns->add<FallbackBatchFunctionOpConversion>(context, corert_converter);
1567
1568   // The patterns below are preferred over fallback lowering as we want to use
1569   // the CoreRT interface for native kernels. This is only temporary and it
1570   // will be refactored to use the SSOT interface.
1571
1572   // Here we use specialized patterns for tf.Const on CPU as it is incorrect
1573   // to use the ExecuteOp pattern to convert a string tensor attribute.
1574 patterns->add<CoreRTConstStringTensorOpConversion,
1575 CoreRTConstDenseTensorOpConversion>(context, corert_converter);
1576
1577 if (enable_native_ops) {
1578     // The TF operations below will be converted to use corert.executeop,
1579     // which invokes the TFRT native op if one is implemented.
1580 // TODO(b/187942369): Pattern registration for TF operations is not
1581 // sustainable currently. We need to figure out a plan.
1582 patterns->add<CoreRTExecuteOpConversion<TF::AddV2Op>,
1583 // TODO(chky): Move the ReadVariableOp + Identity pattern
1584 // to optimizer.
1585 // CoreRTExecuteOpConversion<TF::IdentityOp>,
1586 CoreRTExecuteOpConversion<TF::MulOp>,
1587 CoreRTExecuteOpConversion<TF::BiasAddOp>,
1588 CoreRTExecuteOpConversion<TF::Conv2DOp>,
1589 CoreRTExecuteOpConversion<TF::ConcatV2Op>,
1590 CoreRTExecuteOpConversion<TF::CastOp>,
1591 CoreRTExecuteOpConversion<TF::ExpandDimsOp>,
1592 CoreRTExecuteOpConversion<TF::TransposeOp>,
1593 CoreRTExecuteOpConversion<TF::FusedBatchNormV3Op>,
1594 CoreRTExecuteOpConversion<TF::_FusedBatchNormExOp>,
1595 CoreRTExecuteOpConversion<TF::LogOp>,
1596 CoreRTExecuteOpConversion<TF::Log1pOp>,
1597 CoreRTExecuteOpConversion<TF::LogSoftmaxOp>,
1598 CoreRTExecuteOpConversion<TF::MatMulOp>,
1599 CoreRTExecuteOpConversion<TF::_FusedMatMulOp>,
1600 CoreRTExecuteOpConversion<TF::MaxPoolOp>,
1601 CoreRTExecuteOpConversion<TF::MeanOp>,
1602 CoreRTExecuteOpConversion<TF::MulOp>,
1603 CoreRTExecuteOpConversion<TF::PadOp>,
1604 CoreRTExecuteOpConversion<TF::RealDivOp>,
1605 CoreRTExecuteOpConversion<TF::ReluOp>,
1606 CoreRTExecuteOpConversion<TF::ReshapeOp>,
1607 CoreRTExecuteOpConversion<TF::RsqrtOp>,
1608 CoreRTExecuteOpConversion<TF::SoftmaxOp>,
1609 CoreRTExecuteOpConversion<TF::ShapeOp>,
1610 CoreRTExecuteOpConversion<TF::SigmoidOp>,
1611 CoreRTExecuteOpConversion<TF::SubOp>,
1612 CoreRTExecuteOpConversion<TF::TileOp>,
1613 CoreRTExecuteOpConversion<TF::TanhOp>,
1614 CoreRTExecuteOpConversion<TF::ZerosLikeOp>>(context,
1615 corert_converter);
1616 }
1617 }
1618
1619 // Lower TF dialect MLIR to TFRT dialect.
1620 class TfToTfrtConversionPass
1621 : public mlir::PassWrapper<TfToTfrtConversionPass,
1622 mlir::OperationPass<mlir::ModuleOp>> {
1623   void getDependentDialects(mlir::DialectRegistry &registry) const override {
1624 getDependentConversionDialects(registry);
1625
1626 if (target_tpurt_) RegisterTPUDialects(®istry);
1627 }
1628
1629   llvm::StringRef getArgument() const final { return "tf-to-tfrt"; }
1630   llvm::StringRef getDescription() const final {
1631 return "Convert Tensorflow dialect (generated from tf.function) to TFRT "
1632 "dialect.";
1633 }
1634
1635 public:
1636 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TfToTfrtConversionPass)
1637
1638 TfToTfrtConversionPass() = default;
1639   explicit TfToTfrtConversionPass(const TfrtPipelineOptions &options) {
1640 target_tpurt_ = options.target_tpurt;
1641 enable_native_ops_ = options.enable_native_ops;
1642 tpu_use_core_selector_ = options.tpu_use_core_selector;
1643 tpu_use_bundled_transfer_ = options.tpu_use_bundled_transfer;
1644 tpu_lower_to_fallback_ = options.tpu_lower_to_fallback;
1645 tpu_transfer_result_to_host_ = options.tpu_transfer_result_to_host;
1646 use_tpu_host_allocator_for_inputs_ =
1647 options.use_tpu_host_allocator_for_inputs;
1648 cost_threshold_ = options.cost_threshold;
1649 upper_cost_threshold_ = options.upper_cost_threshold;
1650 merge_inter_dependent_streams_ = options.merge_inter_dependent_streams;
1651 func_use_fallback_tensor_ = options.func_use_fallback_tensor;
1652 enable_while_parallel_iterations_ =
1653 options.enable_while_parallel_iterations;
1654 }
1655   TfToTfrtConversionPass(const TfToTfrtConversionPass &) {}
1656
1657   mlir::LogicalResult runOnFunction(
1658 mlir::func::FuncOp func,
1659 const mlir::TF::SideEffectAnalysis::Info &side_effect_analysis,
1660 const tfrt_compiler::TensorArraySideEffectAnalysis
1661 &tensor_array_side_effect_analysis,
1662 tfrt_compiler::FallbackConverter &fallback_converter,
1663 mlir::SymbolTable &symbol_table) {
1664 auto &context = getContext();
1665 mlir::ConversionTarget target(context);
1666 mlir::RewritePatternSet patterns(&getContext());
1667 CoreRTConverter corert_converter(&context, &side_effect_analysis);
1668 tfrt_compiler::CostAnalysis cost_analysis(func);
1669
1670 if (target_tpurt_)
1671 AddTPUTargetDialectAndPatterns(
1672 &target, &patterns, &context, &corert_converter, &fallback_converter,
1673 TfrtTpuExecuteOpConversionOptions{
1674 tpu_use_core_selector_, tpu_use_bundled_transfer_,
1675 tpu_transfer_result_to_host_, use_tpu_host_allocator_for_inputs_},
1676 tpu_lower_to_fallback_);
1677
1678 mlir::TypeConverter *func_type_converter;
1679 if (func_use_fallback_tensor_) {
1680 func_type_converter = &fallback_converter;
1681 } else {
1682 func_type_converter = &corert_converter;
1683 }
1684 SetUpTFToTFRTConversionLegality(&target, func_type_converter,
1685 corert_converter.chain_type());
1686 PopulateJitRtConversionPatterns(&context, &patterns, &corert_converter);
1687
1688 PopulateTFToTFRTConversionPatterns(
1689 &context, &patterns, &corert_converter, &fallback_converter,
1690 &symbol_table, &cost_analysis, &tensor_array_side_effect_analysis,
1691 enable_native_ops_, func_use_fallback_tensor_,
1692 enable_while_parallel_iterations_, tpu_lower_to_fallback_,
1693 target_tpurt_);
1694
1695 return mlir::applyPartialConversion(func, target, std::move(patterns));
1696 }
1697
1698   void runOnOperation() override {
1699 auto module = getOperation();
1700 const auto &side_effect_analysis =
1701 getAnalysis<mlir::TF::SideEffectAnalysis>();
1702
1703 tfrt_compiler::TensorArraySideEffectAnalysis
1704 tensor_array_side_effect_analysis(module);
1705 tfrt_compiler::FallbackConverter fallback_converter(&getContext());
1706
1707 mlir::SymbolTable symbol_table(module);
1708
1709 auto func_op_range = module.getOps<mlir::func::FuncOp>();
1710 llvm::SmallVector<mlir::func::FuncOp, 4> func_ops(func_op_range.begin(),
1711 func_op_range.end());
1712 for (auto func : func_ops) {
1713 if (!func.isExternal()) {
1714 if (mlir::failed(runOnFunction(
1715 func, side_effect_analysis.GetAnalysisForFunc(func),
1716 tensor_array_side_effect_analysis, fallback_converter,
1717 symbol_table))) {
1718 signalPassFailure();
1719 return;
1720 }
1721
1722 ChainDanglingValuesinFunction(func);
1723 }
1724 }
1725
1726 CreateFallbackInitializationFunction(module, fallback_converter);
1727
1728     // Set the cost threshold as a module attribute. It will be used later at
1729     // runtime to decide whether a certain execution is cheap enough to be
1730     // executed inline.
1731 mlir::Builder builder(module);
1732 module->setAttr("tfrt.cost_threshold",
1733 builder.getI64IntegerAttr(cost_threshold_));
1734 module->setAttr("tfrt.upper_cost_threshold",
1735 builder.getI64IntegerAttr(upper_cost_threshold_));
1736 module->setAttr("tfrt.merge_inter_dependent_streams",
1737 builder.getBoolAttr(merge_inter_dependent_streams_));
1738 }
1739
1740 private:
1741   // Chain all dangling values (i.e. values with no users) together and merge
1742   // them with the first returned chain. This merged chain can be used to
1743   // signal the completion of all execution, including side-effects.
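  // For example, if %unused has no users, the terminator is rewritten roughly
  // as follows (a sketch; tfrt.merge.chains stands for the MergeChainsOp
  // created below):
  //
  //   %merged = tfrt.merge.chains %unused, %ch0
  //   tfrt.return %merged, %results... : ...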
1744   void ChainDanglingValuesinFunction(mlir::func::FuncOp func_op) {
1745 auto &block = func_op.front();
1746
1747 llvm::SmallVector<mlir::Value, 2> dangling_values;
1748
1749 // Find dangling function arguments.
1750 for (auto arg : block.getArguments()) {
1751 if (arg.use_empty()) {
1752 dangling_values.push_back(arg);
1753 }
1754 }
1755
1756 // Find dangling values produced by ops in the function.
1757 for (auto &op : block) {
1758 for (mlir::Value result : op.getResults()) {
1759 if (result.use_empty()) {
1760 dangling_values.push_back(result);
1761 }
1762 }
1763 }
1764
1765 if (dangling_values.empty()) return;
1766
1767 // Merge these dangling values with the first returned chain.
1768 auto return_op =
1769 llvm::cast<tfrt::compiler::ReturnOp>(block.getTerminator());
1770 auto chain = return_op->getOperand(0);
1771 assert(chain.getType().isa<tfrt::compiler::ChainType>());
1772 dangling_values.push_back(chain);
1773
1774 mlir::OpBuilder builder(return_op);
1775
1776 auto new_chain = builder.create<tfrt::compiler::MergeChainsOp>(
1777 return_op->getLoc(), builder.getType<tfrt::compiler::ChainType>(),
1778 dangling_values);
1779
1780 return_op->setOperand(0, new_chain);
1781 }
1782
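  // Creates a "_tfrt_fallback_init" function that eagerly instantiates all
  // fallback kernels and pre-compiles all JIT clusters found in the module,
  // returning a chain that signals when initialization has finished. The
  // generated function looks roughly like (an illustrative sketch; the op
  // spellings may differ from the actual printed form):
  //
  //   func @_tfrt_fallback_init(%in: !tfrt.chain) -> !tfrt.chain {
  //     %ch0 = tfrt_fallback_async.createop(%in) {..., num_args = N}
  //     %ch1 = tf_jitrt.fallback.compile ...
  //     %out = tfrt.merge.chains %ch0, %ch1
  //     tfrt.return %out : !tfrt.chain
  //   }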
1783   void CreateFallbackInitializationFunction(
1784 mlir::ModuleOp module,
1785 tfrt_compiler::FallbackConverter &fallback_converter) {
1786 mlir::OpBuilder builder(&module.getBodyRegion());
1787
1788 auto chain_type = builder.getType<tfrt::compiler::ChainType>();
1789
1790 auto func_op = builder.create<mlir::func::FuncOp>(
1791 module.getLoc(), "_tfrt_fallback_init",
1792 mlir::FunctionType::get(module.getContext(), /*inputs=*/chain_type,
1793 /*results=*/chain_type));
1794
1795 auto *block = func_op.addEntryBlock();
1796 builder.setInsertionPointToStart(block);
1797
1798 mlir::Value chain_value = block->getArgument(0);
1799
1800 // Create operations for all fallback kernels in the module.
1801 for (auto *op : fallback_converter.GetFallbackOps()) {
1802 auto create_op = builder.create<tfrt::fallback_async::CreateOp>(
1803 func_op.getLoc(), chain_type, chain_value);
1804
1805 create_op->setAttrs(op->getAttrs());
1806 create_op->setAttr("num_args", builder.getI64IntegerAttr(GetNumArgs(op)));
1807
1808 chain_value = create_op;
1809 }
1810
1811 // Pre-compile all JIT compiled kernels found in the module.
1812 llvm::SmallVector<Value> compiled;
1813
1814     // A set of SymbolRef attributes referencing compiled kernels.
1815 llvm::DenseSet<mlir::Attribute> kernels;
1816
1817     // Compile all kernels in parallel.
1818 module.walk([&](tf_jitrt::FallbackExecuteOp execute) {
1819       // Do not compile the same kernel multiple times.
1820 if (kernels.contains(execute.kernel())) return;
1821
1822 auto compile = builder.create<tf_jitrt::FallbackCompileOp>(
1823 execute.getLoc(), chain_type, execute.kernel(), execute.device());
1824 compiled.push_back(compile.getResult());
1825 kernels.insert(compile.kernel());
1826 });
1827
1828     // Wait for compilation to complete before returning from the init function.
1829 if (!compiled.empty()) {
1830       // Do not forget to wait for the fallback kernels' initialization.
1831 compiled.insert(compiled.begin(), chain_value);
1832 chain_value = builder.create<tfrt::compiler::MergeChainsOp>(
1833 func_op.getLoc(), chain_type, compiled);
1834 }
1835
1836 builder.create<tfrt::compiler::ReturnOp>(func_op.getLoc(), chain_value);
1837 }
1838
1839   int64_t GetNumArgs(mlir::Operation *fallback_op) {
1840 if (auto execute_op =
1841 llvm::dyn_cast<tfrt::fallback_async::ExecuteOp>(fallback_op)) {
1842 return execute_op.operands().size();
1843 } else if (auto execute_op_seq =
1844 llvm::dyn_cast<tfrt::fallback_async::ExecuteOpSeq>(
1845 fallback_op)) {
1846 return execute_op_seq.operands().size();
1847 } else if (auto execute_op_allocator =
1848 llvm::dyn_cast<tfrt::fallback_async::ExecuteOpWithAllocator>(
1849 fallback_op)) {
1850 return execute_op_allocator.operands().size();
1851 } else if (auto execute_op_seq_allocator = llvm::dyn_cast<
1852 tfrt::fallback_async::ExecuteOpSeqWithAllocator>(
1853 fallback_op)) {
1854 return execute_op_seq_allocator.operands().size();
1855 }
1856 llvm_unreachable("invalid fallback op type");
1857 }
1858
1859 Option<bool> target_tpurt_{*this, "target-tpurt",
1860 llvm::cl::desc("Target TPURT dialect if true."),
1861 llvm::cl::init(false)};
1862
1863 Option<bool> enable_native_ops_{
1864 *this, "enable-native-ops",
1865 llvm::cl::desc(
1866 "If true, native ops will be used on an opt-in basis "
1867 "instead of fallback ops. If false, no native ops are used."),
1868 llvm::cl::init(true)};
1869
1870 Option<bool> tpu_use_core_selector_{
1871 *this, "tpu-use-core-selector",
1872 llvm::cl::desc("If true, use ServingCoreSelector to pick TPU core. "
1873 "Otherwise, use the assigned core."),
1874 llvm::cl::init(true)};
1875
1876 Option<bool> tpu_use_bundled_transfer_{
1877 *this, "tpu-use-bundled-transfer",
1878 llvm::cl::desc("If true, use BundledTransferToTpuOp to transfer "
1879 "variables and input tensors to TPU."),
1880 llvm::cl::init(true)};
1881
1882 Option<bool> tpu_lower_to_fallback_{
1883 *this, "tpu-lower-to-fallback",
1884 llvm::cl::desc("If true, lower a TF op that's placed on TPU device "
1885 "to be executed by tfrt_fallback.execute. Note that this "
1886 "applies to ops other than TPU compile and execute ops."),
1887 llvm::cl::init(true)};
1888
1889 // TODO(b/194081364): remove this option once we unify servo TPU serving
1890 // result transfer behavior.
1891 Option<bool> tpu_transfer_result_to_host_{
1892 *this, "tpu-transfer-result-to-host",
1893 llvm::cl::desc("If true, transfer the result of tpurt.execute from TPU "
1894 "to host."),
1895 llvm::cl::init(true)};
1896
1897 Option<bool> use_tpu_host_allocator_for_inputs_{
1898 *this, "use-tpu-host-allocator-for-inputs",
1899 llvm::cl::desc("If true, fallback executeops that produce inputs to tpu "
1900 "program will use tpu host allocator."),
1901 llvm::cl::init(false)};
1902
1903 Option<uint64_t> cost_threshold_{
1904 *this, "tfrt-cost-threshold",
1905 llvm::cl::desc(
1906 "The cost threshold to decide whether a sequence of operations is "
1907 "cheap, and then whether it can be executed inline."),
1908 llvm::cl::init(1)};
1909
1910 Option<int64_t> upper_cost_threshold_{
1911 *this, "tfrt-upper-cost-threshold",
1912 llvm::cl::desc(
1913           "The threshold to limit the merging of dependent sequences."),
1914 llvm::cl::init(-1)};
1915
1916 Option<bool> merge_inter_dependent_streams_{
1917 *this, "tfrt-merge-inter-dependent-streams",
1918       llvm::cl::desc("If true, streams with inter-data dependencies will be "
1919 "preferred to be merged for inline execution."),
1920 llvm::cl::init(false)};
1921
1922 Option<bool> func_use_fallback_tensor_{
1923 *this, "func-use-fallback-tensor",
1924 llvm::cl::desc(
1925 "If true, use TF tensor as input/output types in func (and other "
1926 "control flow) ops."),
1927 llvm::cl::init(false)};
1928
1929 Option<bool> enable_while_parallel_iterations_{
1930 *this, "enable-while-parallel-iterations",
1931 llvm::cl::desc("If true, tf.While op will be parallelized. This is "
1932 "currently experimental."),
1933 llvm::cl::init(false)};
1934 };
1935
1936 // Assigns devices so that later passes can utilize device information.
1937 // Device assignment might not have been done by the upstream pipeline, or it
1938 // might have been removed by previous passes. However, we assume most of the
1939 // device assignment has been done by the upstream pipeline, so we simply
1940 // assign the default device to unassigned ops. Specifically, we do the
1941 // assignment for ConstOp first, to place it on the same device as its user
1942 // operation instead of placing it on the default device blindly.
1943 // TODO(b/221297389): Figure out a more robust way to handle dropped device
1944 // assignment.
1945 void AddTfDeviceAssignmentPasses(mlir::OpPassManager &pm,
1946 const TfrtPipelineOptions &options) {
1947 pm.addPass(mlir::TF::CreateConstantOpDeviceAssignmentPass());
1948 pm.addNestedPass<mlir::func::FuncOp>(
1949 mlir::TF::CreateTFDeviceAssignmentByFuncAttrPass());
1950 pm.addNestedPass<mlir::func::FuncOp>(
1951 mlir::TF::CreateSimpleTFDeviceAssignmentPass(options.default_device));
1952 }
1953
1954 } // namespace
1955
1956 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
1957 CreateTfToTfrtConversionPass(const TfrtPipelineOptions &options) {
1958 return std::make_unique<TfToTfrtConversionPass>(options);
1959 }
1960
1961 // -------------------------------------------------------------------------- //
1962 // Outline tf_device.cluster operation regions into functions in the nested
1963 // modules and replaces all cluster operations with jitrt.call operations.
1964 // -------------------------------------------------------------------------- //
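//
// For example, a cluster formed for TFRT auto-fusion, roughly (an
// illustrative sketch; the "policy" attribute spelling matches the check
// below):
//
//   %0 = "tf_device.cluster"() ({
//     ... fused TF ops ...
//     tf_device.return %r : tensor<?xf32>
//   }) {policy = "tfrt.auto-fusion"} : () -> tensor<?xf32>
//
// is replaced with a call into an outlined nested module:
//
//   %0 = jitrt.call @kernel::@compute(%live_ins) : (...) -> tensor<?xf32>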
1965
1966 class OutlineJitRtClustersPass
1967 : public PassWrapper<OutlineJitRtClustersPass, OperationPass<ModuleOp>> {
1968 public:
1969   llvm::StringRef getArgument() const final {
1970 return "tf-outline-jitrt-cluster";
1971 }
1972   llvm::StringRef getDescription() const final {
1973 return "Outlines `tf_device.cluster` operations into functions and "
1974 "replaces them with `jitrt.call` operations.";
1975 }
1976
1977 void runOnOperation() override;
1978
1979   void getDependentDialects(mlir::DialectRegistry &registry) const override {
1980 registry.insert<tfrt::jitrt::JitRuntimeDialect>();
1981 }
1982
1983 private:
1984 struct CompiledModule {
1985 ModuleOp module;
1986 func::FuncOp entrypoint;
1987 llvm::SetVector<Value> operands;
1988 };
1989
1990 // Creates a nested module with a single function that will be compiled into
1991 // the kernel at runtime.
1992 CompiledModule CreateCompiledModule(tf_device::ClusterOp cluster,
1993 int64_t max_arg_size,
1994 SymbolTable *symbol_table);
1995
1996 // Update compiled module entrypoint signature with inferred operands
1997 // constraints.
1998 LogicalResult SetEntrypointConstraints(CompiledModule &compiled);
1999
2000 // Outlines cluster operation regions into compiled modules, and replaces
2001 // cluster operation with a jitrt.call operation.
2002 LogicalResult OutlineClusterOp(tf_device::ClusterOp cluster,
2003 int64_t max_arg_size,
2004 SymbolTable *symbol_table);
2005
2006 // Mapping from the outlined module string representation to the module itself
2007 // and an entrypoint function. Used to deduplicate identical modules during
2008 // the `tf_device.cluster` outlining.
2009 llvm::StringMap<std::pair<ModuleOp, func::FuncOp>> outlined_;
2010 };
2011
2012 OutlineJitRtClustersPass::CompiledModule
2013 OutlineJitRtClustersPass::CreateCompiledModule(tf_device::ClusterOp cluster,
2014 int64_t max_arg_size,
2015 SymbolTable *symbol_table) {
2016 MLIRContext *ctx = cluster->getContext();
2017 Location loc = cluster.getLoc();
2018
2019   // Create a module that will hold the compiled function and async wrappers.
2020   // TODO(ezhulenev): Give better names to the module and function.
2021 auto compiled_module = ModuleOp::create(loc, {"kernel"});
2022 compiled_module->setAttr("tfrt.compiled", UnitAttr::get(ctx));
2023 compiled_module->setAttr(
2024 "tfrt.max-arg-size",
2025 IntegerAttr::get(IntegerType::get(ctx, 64), max_arg_size));
2026
2027 SymbolTable compiled_module_symbol_table(compiled_module);
2028
2029 // Find out the cluster arguments and their types.
2030 llvm::SetVector<Value> live_ins;
2031 getUsedValuesDefinedAbove(cluster.body(), cluster.body(), live_ins);
2032
2033 llvm::SmallVector<Type, 4> operand_types;
2034 operand_types.reserve(live_ins.size());
2035 for (Value v : live_ins) operand_types.emplace_back(v.getType());
2036
2037 // Create a function in the compiled module.
2038 auto compiled_func_type =
2039 FunctionType::get(ctx, operand_types, cluster->getResultTypes());
2040 auto compiled_func = func::FuncOp::create(loc, "compute", compiled_func_type);
2041 compiled_module_symbol_table.insert(compiled_func);
2042
2043 // Replace uses of live-in values within cluster region with block arguments.
2044 Block *compiled_func_block = compiled_func.addEntryBlock();
2045 for (auto p : llvm::zip(live_ins, compiled_func_block->getArguments()))
2046 replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), cluster.body());
2047
2048 // Move all operations in cluster into compiled_func's entry block.
2049 auto &cluster_body = cluster.GetBody().getOperations();
2050 compiled_func_block->getOperations().splice(
2051 compiled_func_block->end(), cluster_body, cluster_body.begin(),
2052 cluster_body.end());
2053
2054 // Replace `tf_device.return` terminator with `func.return` in the function
2055 // body.
2056 auto device_return =
2057 cast<tf_device::ReturnOp>(compiled_func_block->getTerminator());
2058 OpBuilder builder(device_return.getOperation());
2059 builder.create<func::ReturnOp>(device_return.getLoc(),
2060 device_return.getOperands());
2061 device_return.erase();
2062
2063   // TODO(ezhulenev): MLIR doesn't define operation equivalence upstream yet;
2064   // replace module printing with a more principled solution when available.
2065   // Operations in the cluster can appear in a different order yet define
2066   // identical Tensorflow programs; with the current approach we will not be
2067   // able to detect such duplicates.
2068
2069 // Remove location attribute attached to Tensorflow operations to be able to
2070 // deduplicate compiled clusters with the same set of operations.
2071 //
2072   // TODO(ezhulenev): Figure out how to propagate locations for error
2073   // reporting; right now JitRt ignores them anyway.
2074 compiled_module.walk([](Operation *op) { op->removeAttr("_class"); });
2075
2076 // Serialize prepared module to string.
2077 std::string serialized;
2078 llvm::raw_string_ostream os(serialized);
2079 compiled_module.print(os);
2080
2081 // Try to find if identical module was already outlined.
2082 auto it = outlined_.find(serialized);
2083
2084 // Return identical module that was already outlined earlier.
2085 if (it != outlined_.end()) {
2086 compiled_module.erase(); // erase identical module
2087 return {it->second.first, it->second.second, live_ins};
2088 }
2089
2090 // Insert compiled module into the symbol table and assign it a unique name.
2091 symbol_table->insert(compiled_module);
2092
2093 // Cache unique module.
2094 outlined_.insert({std::move(serialized), {compiled_module, compiled_func}});
2095
2096 return {compiled_module, compiled_func, live_ins};
2097 }
2098
2099 LogicalResult OutlineJitRtClustersPass::SetEntrypointConstraints(
2100 CompiledModule &compiled) {
2101 func::FuncOp func = compiled.entrypoint;
2102
2103 // Functions outlined from jitrt device clusters must have a single block.
2104 assert(func.getBody().getBlocks().size() == 1 && "expected single block");
2105
2106 mlir::TFDevice::ClusteringPolicySet policies;
2107 populateTfJitRtConstraintsPolicies(policies);
2108
2109 // Infer constraints on the values defined in the entrypoint function
2110 // (including function entry block arguments).
2111 mlir::TFDevice::ValuesConstraintSet constraints;
2112 if (failed(mlir::TFDevice::PropagateValuesConstraints(
2113 func.getBody(), policies, constraints, /*resolve=*/true)))
2114 return failure();
2115
2116 // Annotate arguments with inferred constraints.
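  // For example, an argument whose inferred constraint resolves to "rank"
  // would end up annotated roughly as (a sketch; the constraint string is
  // whatever llvm::formatv produces for the inferred constraint):
  //
  //   func @compute(%arg0: tensor<*xf32> {rt.constraint = "rank"}) -> ...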
2117 for (unsigned i = 0; i < func.getNumArguments(); ++i) {
2118 if (auto constraint = constraints.GetConstraint(func.getArgument(i))) {
2119 auto constraint_name = mlir::StringAttr::get(
2120 &getContext(), llvm::formatv("{0}", *constraint).str());
2121 func.setArgAttr(i, "rt.constraint", constraint_name);
2122 }
2123 }
2124
2125 return success();
2126 }
2127
2128 LogicalResult OutlineJitRtClustersPass::OutlineClusterOp(
2129 tf_device::ClusterOp cluster, int64_t max_arg_size,
2130 SymbolTable *symbol_table) {
2131 Location loc = cluster->getLoc();
2132 OpBuilder builder(cluster);
2133
2134 CompiledModule compiled_module =
2135 CreateCompiledModule(cluster, max_arg_size, symbol_table);
2136 func::FuncOp compiled_func = compiled_module.entrypoint;
2137
2138 // Add constraints to the entrypoint arguments.
2139 if (failed(SetEntrypointConstraints(compiled_module))) return failure();
2140
2141 // Replace device cluster with a jitrt.call operation.
2142 auto module_name = *compiled_module.module.getSymName();
2143 auto func_name = compiled_func.getSymName();
2144 auto func_flat_ref =
2145 mlir::SymbolRefAttr::get(builder.getContext(), func_name);
2146 auto func_ref = mlir::SymbolRefAttr::get(builder.getContext(), module_name,
2147 {func_flat_ref});
2148
2149 auto cluster_func_op = builder.create<tfrt::jitrt::CallOp>(
2150 loc, cluster.getResultTypes(), func_ref,
2151 compiled_module.operands.getArrayRef());
2152
2153 cluster.replaceAllUsesWith(cluster_func_op);
2154 cluster.erase();
2155
2156 return success();
2157 }
2158
2159 void OutlineJitRtClustersPass::runOnOperation() {
2160 ModuleOp module = getOperation();
2161 SymbolTable symbol_table(module);
2162
2163 // Keep track of the maximum argument size for each function with tf_device
2164 // cluster operations in the function body. We need to pass it to the compiled
2165 // module to correctly compute its cost later.
2166 llvm::DenseMap<mlir::func::FuncOp, int64_t> max_arg_size_map;
2167
2168 auto get_max_arg_size = [&](mlir::func::FuncOp func) -> int64_t {
2169 auto it = max_arg_size_map.find(func);
2170 if (it != max_arg_size_map.end()) return it->second;
2171 return max_arg_size_map[func] = tf_jitrt::GetMaxArgSize(func);
2172 };
2173
2174 OpBuilder builder(module.getContext());
2175 auto result = module.walk([&](tf_device::ClusterOp cluster) -> WalkResult {
2176 // Ensure that cluster was formed for TFRT JIT compilation.
2177 auto policy = cluster->getAttr("policy").dyn_cast_or_null<StringAttr>();
2178 if (!policy || policy.getValue() != "tfrt.auto-fusion")
2179 return WalkResult::advance();
2180
2181 // Get the maximum argument size of the parent function.
2182 mlir::func::FuncOp parent_func =
2183 cluster->getParentOfType<mlir::func::FuncOp>();
2184 int64_t max_arg_size = get_max_arg_size(parent_func);
2185
2186 if (failed(OutlineClusterOp(cluster, max_arg_size, &symbol_table)))
2187 return WalkResult::interrupt();
2188 return WalkResult::advance();
2189 });
2190
2191 if (result.wasInterrupted()) {
2192 module->emitError("Failed to outline tf_device.cluster operations");
2193 signalPassFailure();
2194 }
2195 }
2196
2197 static std::unique_ptr<Pass> CreateOutlineJitRtClustersPass() {
2198 return std::make_unique<OutlineJitRtClustersPass>();
2199 }
2200
2201 // -------------------------------------------------------------------------- //
2202
2203 void CreateTFExecutorToTFPipeline(mlir::OpPassManager &pm,
2204 const TfrtPipelineOptions &options) {
2205 // Due to b/191304670, functionalized while ops might not have the
2206 // shape_invariant attribute set correctly, which leads to failure in shape
2207   // inference. As a workaround, we conservatively (i.e., we place fewer
2208   // restrictions on tf.While, which avoids failures but leads to potentially
2209   // less exact shape inference) set the shape_invariant attribute in all
2210 // tf.While ops before performing shape inference.
2211 //
2212 // Note that this pass might not work well with TF XLA bridge, but this is
2213 // fine as TF XLA bridge is run before this pipeline. For CPU ops, less exact
2214 // shape inference may lead to fewer optimizations but it should be fine as it
2215 // is limited to while ops currently.
2216 //
2217 // TODO(b/191304670): Remove this pass once the shape_invariant attribute is
2218 // set correctly in the upstream.
2219 pm.addNestedPass<mlir::func::FuncOp>(
2220 tfrt_compiler::CreateSetShapeInvariantInWhileOps());
2221
2222   // We pass the MLIR module through the TF standard pipeline, which, for
2223   // instance, does shape inference, canonicalization, inlining, etc.
2224 pm.addNestedPass<mlir::func::FuncOp>(
2225 mlir::tf_executor::CreateTFExecutorGraphPruningPass());
2226 pm.addNestedPass<mlir::func::FuncOp>(
2227 mlir::tf_executor::CreateTFExecutorIslandCoarseningPass());
2228
2229 AddTfDeviceAssignmentPasses(pm, options);
2230
2231   // Here we perform TFRT-specific optimization before standard TF optimization,
2232 // as TFRT-specific optimization may create more opportunities.
2233 pm.addNestedPass<mlir::func::FuncOp>(
2234 tfrt_compiler::CreateOptimizeTfForTfrtPass());
2235 pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
2236 pm.addPass(mlir::TF::CreateTFShapeInferencePass());
2237 pm.addPass(mlir::createInlinerPass());
2238 pm.addPass(mlir::createSymbolDCEPass());
2239 pm.addNestedPass<mlir::func::FuncOp>(mlir::TF::CreateTFOptimizePass());
2240 pm.addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());
2241
2242 AddTfDeviceAssignmentPasses(pm, options);
2243
2244   // After the standard passes, the MLIR is in the TF dialect, and we now
2245   // convert reference variables to resource variables, which is best-effort.
2246 pm.addPass(CreateConvertReferenceVariableToResourceVariablePass());
2247
2248 // Move the tf.Assert op to the end of the function, so that it does not
2249 // impose unnecessary control dependencies on other ops.
2250 pm.addPass(tfrt_compiler::CreateReorderTfAssertPass());
2251
2252   // Optimize the side-effects of control flow ops by examining the ops in
2253   // their callees.
2254 pm.addPass(tfrt_compiler::CreateOptimizeTfControlFlowSideEffectPass());
2255
2256 // Remove tf.If ops' operands that are produced by tf.Const ops.
2257 pm.addPass(tfrt_compiler::CreateRemoveTfIfConstArgsPass());
2258
2259 // Merge non-side-effecting tf.If ops if their operands are the same.
2260 pm.addPass(tfrt_compiler::CreateMergeTfIfOpsPass());
2261
2262   // Deduplicate functions invoked by tf.BatchFunction with the same
2263   // shared_name.
2264 pm.addPass(
2265 tfrt_compiler::CreateDeduplicateFunctionsInovkedByBatchFunctionPass());
2266
2267 // Apply standard optimization after optimizing control flow ops.
2268 pm.addPass(mlir::createInlinerPass());
2269 pm.addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());
2270
2271   // TODO(b/187876545): An extra shape inference pass is added because shape
2272   // inference does not work well with the tf.Identity op that removes the ref
2273   // type, so we perform shape inference again after the reference variable to
2274   // resource variable conversion. We should remove this after b/187876545 is fixed.
2275 pm.addPass(mlir::TF::CreateTFShapeInferencePass());
2276
2277 pm.addNestedPass<mlir::func::FuncOp>(
2278 mlir::TFDevice::CreateLaunchToDeviceAttributePass());
2279
2280 // After all standard passes run layout optimization to assign optimal data
2281 // format for all layout sensitive operations.
2282 mlir::TF::LayoutOptimizationPipelineOptions layout_optimization_options;
2283 layout_optimization_options.force_data_format =
2284 options.force_data_format.getValue();
2285 // TODO(b/191304261): Folding transpose in ops is buggy in the layout
2286 // optimization pass. Disable it to avoid errors in b/191304261. This should
2287 // not affect CPU performance as it does not change the number of ops, nor
2288 // does it change the types of the ops.
2289 layout_optimization_options.skip_fold_transpose_in_ops = true;
2290 mlir::TF::CreateLayoutOptimizationPipeline(pm.nest<mlir::func::FuncOp>(),
2291 layout_optimization_options);
2292
2293 // Run canonicalization pipeline to remove unused constants and bypassed
2294 // transpose operations left in the IR after layout optimization.
2295 pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
2296
2297 // Decompose resource ops as resource variables will be converted to tensors
2298 // directly.
2299 if (options.decompose_resource_ops)
2300 pm.addNestedPass<mlir::func::FuncOp>(
2301 mlir::TFDevice::CreateDecomposeResourceOpsPass());
2302
2303 AddTfDeviceAssignmentPasses(pm, options);
2304
2305 pm.addNestedPass<mlir::func::FuncOp>(
2306 mlir::TF::CreateTensorDeviceCopyConversionPass());
2307
2308   // Outline auto-fusion clusters into tf_device.cluster operations and then
2309 // convert them to functions. We currently support only tfrt fallback tensors
2310 // as operands, so we disable these passes if we can have native ops after
2311 // lowering.
2312 if (!options.enable_native_ops) {
2313 pm.addNestedPass<mlir::func::FuncOp>(CreateTfJitRtClusteringPass(
2314 options.auto_fusion_oplist, options.auto_fusion_min_cluster_size));
2315
2316 // Sink small constants into the outlined clusters to reduce the number of
2317 // arguments for each of the execute operations.
2318 auto is_compilable_const = [](mlir::tf_device::ClusterOp cluster,
2319 mlir::ElementsAttr value) -> bool {
2320 // Ensure that cluster was formed for TFRT JIT compilation.
2321 auto policy = cluster->getAttr("policy").dyn_cast_or_null<StringAttr>();
2322 if (!policy || policy.getValue() != "tfrt.auto-fusion") return false;
2323
2324 // Check that TF->JitRt compiler supports constant compilation.
2325 return mlir::succeeded(IsCompilableConstant(value));
2326 };
2327
2328 pm.addNestedPass<mlir::func::FuncOp>(
2329 mlir::TFDevice::CreateClusterConstantSinkingPass(is_compilable_const));
2330
2331 // Outline formed JIT compiled device clusters into function.
2332 pm.addPass(CreateOutlineJitRtClustersPass());
2333 }
2334
2335   // Rewrite operation sequences into device-specific fusions.
2336 DeviceNameUtils::ParsedName parsed_name;
2337
2338 // Ignore error.
2339 bool success =
2340 DeviceNameUtils::ParseFullName(options.default_device, &parsed_name);
2341 assert(success && "default device is invalid");
2342 (void)success;
2343
2344 if (parsed_name.has_type && parsed_name.type == DEVICE_GPU)
2345 pm.addNestedPass<mlir::func::FuncOp>(mlir::TF::CreateGpuOpFusionPass());
2346
2347 if (parsed_name.has_type && parsed_name.type == DEVICE_CPU)
2348 pm.addNestedPass<mlir::func::FuncOp>(
2349 mlir::TF::CreateFusedKernelMatcherPass());
2350
2351 if (options.tpu_fuse_ops) {
2352 pm.addNestedPass<mlir::func::FuncOp>(
2353 tfrt_compiler::CreateFuseTpuCompileAndExecutePass());
2354     // Remove ops for the input to _TPUCompileMlirOp, which are no longer
2355     // needed after CreateFuseTpuCompileAndExecutePass.
2356 pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
2357 }
2358
2359 AddTfDeviceAssignmentPasses(pm, options);
2360
2361 pm.addPass(CreateLowerTFSavedModelPass(options.hoist_invariant_ops));
2362 }
2363
2364 void CreateTfExecutorToTfrtPipelineHelper(mlir::OpPassManager &pm,
2365 const TfrtPipelineOptions &options) {
2366 CreateTFExecutorToTFPipeline(pm, options);
2367
2368 pm.addPass(CreateTfToTfrtConversionPass(options));
2369
2370 pm.addPass(CreateRemoveDeviceAttributePass());
2371
2372 // Run optimizer on the MLIR module in CoreRT dialect.
2373 if (options.enable_optimizer) {
2374 pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
2375 pm.addPass(mlir::createInlinerPass());
2376 pm.addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());
2377 pm.addNestedPass<mlir::func::FuncOp>(
2378 tfrt_compiler::CreateInsertFallbackTensorCopyPass());
2379 }
2380 }
2381
2382 // If verbose logging is on, dump the output of each pass to a directory,
2383 // set via env var TF_DUMP_GRAPH_PREFIX. e.g.:
2384 // export TF_DUMP_GRAPH_PREFIX=/tmp/mlir
2385 void CreateTfExecutorToTfrtPipeline(mlir::PassManager &pm,
2386 const TfrtPipelineOptions &options) {
2387 if (VLOG_IS_ON(1)) {
2388 // Print the whole module after each pass, which requires disabling
2389 // multi-threading as well.
2390 pm.getContext()->disableMultithreading();
2391 pm.enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>(
2392 /*print_module_scope=*/true));
2393 }
2394 CreateTfExecutorToTfrtPipelineHelper(pm, options);
2395 }
2396
2397 static mlir::PassRegistration<TfToTfrtConversionPass> tf_to_tfrt_pass;
2398
2399 static mlir::PassRegistration<OutlineJitRtClustersPass>
2400 tf_outline_jitrt_cluster_pass(CreateOutlineJitRtClustersPass);
2401
2402 static mlir::PassPipelineRegistration<TfrtPipelineOptions> tf_pipeline(
2403 "tf-executor-to-tfrt-pipeline",
2404 "Convert Tensorflow Executor dialect to TFRT dialect and "
2405 "also apply necessary optimization passes.",
2406 CreateTfExecutorToTfrtPipelineHelper);
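// With these registrations linked into an MLIR opt-style driver, the pass and
// the pipeline can be exercised from the command line via their registered
// argument names, e.g. (illustrative; the actual driver binary depends on the
// build target):
//
//   <opt-driver> --tf-to-tfrt input.mlir
//   <opt-driver> --tf-executor-to-tfrt-pipeline input.mlir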
2407
2408 } // namespace tensorflow
2409