1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements lowering of TF dialect to TFRT CoreRuntime ExecuteOp.
17 // This lowering pass is heavily experimental and incomplete. External code
18 // should not depend on the code here, and please do not treat it as an
19 // example of "the path forward".
20 
21 #include <vector>
22 
23 #include "mlir/Dialect/Func/IR/FuncOps.h"
24 #include "mlir/IR/Attributes.h"
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/Dialect.h"
27 #include "mlir/IR/OperationSupport.h"
28 #include "mlir/IR/Types.h"
29 #include "mlir/Pass/PassManager.h"
30 #include "mlir/Pass/PassOptions.h"
31 #include "mlir/Transforms/DialectConversion.h"
32 #include "mlir/Transforms/Passes.h"
33 #include "mlir/Transforms/RegionUtils.h"
34 #include "llvm/ADT/DenseSet.h"
35 #include "llvm/ADT/iterator_range.h"
36 #include "llvm/Support/FormatVariadic.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
39 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
40 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
41 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
42 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
45 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
46 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
47 #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
48 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
49 #include "tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h"
50 #include "tensorflow/compiler/mlir/tfrt/analysis/tensor_array_side_effect_analysis.h"
51 #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.h"
52 #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h"
53 #include "tensorflow/compiler/mlir/tfrt/jit/opdefs/tf_jitrt_ops.h"
54 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.h"
55 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"
56 #include "tensorflow/compiler/mlir/tfrt/transforms/corert_converter.h"
57 #include "tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.h"
58 #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
59 #include "tensorflow/compiler/mlir/tfrt/transforms/set_shape_invariant_in_while_ops.h"
60 #include "tensorflow/core/framework/tensor.h"
61 #include "tensorflow/core/framework/types.h"
62 #include "tensorflow/core/platform/tstring.h"
63 #include "tfrt/jitrt/opdefs/jitrt_ops.h"  // from @tf_runtime
64 #include "tfrt/basic_kernels/opdefs/basic_kernels.h"  // from @tf_runtime
65 #include "tfrt/basic_kernels/opdefs/tfrt_base.h"  // from @tf_runtime
66 #include "tfrt/basic_kernels/opdefs/types.h"  // from @tf_runtime
67 #include "tfrt/core_runtime/opdefs/attributes.h"  // from @tf_runtime
68 #include "tfrt/core_runtime/opdefs/core_runtime.h"  // from @tf_runtime
69 #include "tfrt/core_runtime/opdefs/types.h"  // from @tf_runtime
70 #include "tfrt/distributed_runtime/opdefs/kernels.h"  // from @tf_runtime
71 #include "tfrt/distributed_runtime/opdefs/types.h"  // from @tf_runtime
72 #include "tfrt/test_kernels/opdefs/test_kernels.h"  // from @tf_runtime
73 
74 namespace tensorflow {
75 namespace {
76 
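// Pattern benefits: MLIR tries patterns with a larger benefit first, so the
// CoreRT patterns below take precedence over the generic fallback pattern
// when both match.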
77 constexpr unsigned kFallbackBenefit = 1;
78 constexpr unsigned kCoreRTBenefit = 2;
79 constexpr char kGpuDeviceName[] =
80     "/job:localhost/replica:0/task:0/device:GPU:0";
81 constexpr char kTFDeviceAttr[] = "tf.device";
82 constexpr char kTFRTDeviceAttr[] = "tfrt.device";
83 constexpr char kDeviceAttr[] = "device";
84 constexpr char kHostAttr[] = "host";
85 constexpr char kDeviceTypeTpu[] = "TPU";
86 constexpr int64_t kDefaultCheapCost = 1;
87 
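// Registers the dialects that the lowered module may reference, so the
// conversion framework loads them before the patterns run.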
88 void getDependentConversionDialects(mlir::DialectRegistry &registry) {
89   registry.insert<tfrt::corert::CoreRTDialect, mlir::func::FuncDialect,
90                   tfrt::fallback_async::FallbackAsyncDialect,
91                   tfrt::compiler::TFRTDialect, tfrt::dist::DistributedDialect,
92                   tf_jitrt::JitRuntimeDialect>();
93 }
94 
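// Returns the side-effect chain passed into the enclosing function. By the
// convention established by this pass (see TFRTFuncOpSignatureConversion
// below), the chain is the function's first argument.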
95 mlir::Value GetFunctionInputChain(mlir::Operation *op) {
96   auto func_op = op->getParentOfType<mlir::func::FuncOp>();
97   return func_op.getArgument(0);
98 }
99 
100 // Convert TF dialect ops to tfrt_fallback.executeop for non-side-effecting ops
101 // and tfrt_fallback.executeop.seq for side-effecting ops.
102 //
103 // For example,
104 //
105 // %0 = "tf.MatMul"(%arg0, %arg1) {device = "cpu"} :
106 //    (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
107 //
108 // is converted to
109 //
110 // %result = tfrt_fallback.executeop device("cpu")
111 //    "tf.MatMul"(%arg0, %arg1) : 1
112 //
113 class FallbackExecuteOpConversion : public mlir::ConversionPattern {
114  public:
115   FallbackExecuteOpConversion(
116       mlir::MLIRContext *context, CoreRTConverter *corert_converter,
117       tfrt_compiler::FallbackConverter *fallback_converter,
118       const tfrt_compiler::CostAnalysis *cost_analysis,
119       bool tpu_lower_to_fallback, bool target_tpurt)
120       : mlir::ConversionPattern(mlir::Pattern::MatchAnyOpTypeTag(),
121                                 kFallbackBenefit, context),
122         corert_converter_(*corert_converter),
123         fallback_converter_(*fallback_converter),
124         cost_analysis_(*cost_analysis),
125         tpu_lower_to_fallback_(tpu_lower_to_fallback),
126         target_tpurt_(target_tpurt) {}
127 
128   LogicalResult matchAndRewrite(
129       mlir::Operation *op, ArrayRef<mlir::Value> operands,
130       mlir::ConversionPatternRewriter &rewriter) const override {
131     if (!UseFallback(op)) return failure();
132 
133     if (target_tpurt_ && IsTpuCompileAndExecuteOps(op)) return failure();
134 
135     corert_converter_.MaterializeDerivedAttributes(op);
136 
137     mlir::StringAttr device = op->getAttrOfType<mlir::StringAttr>(kDeviceAttr);
138     if (!device || device.getValue().empty())
139       return op->emitWarning("failed to find a non-empty 'device' attribute");
140     op->removeAttr(kDeviceAttr);
141     auto parsed_device_name =
142         corert_converter_.ParseDeviceName(device.getValue());
143     if (!parsed_device_name)
144       return op->emitWarning("failed to parse the device name");
145 
146     // Convert the function (symbol) attributes to an array of string
147     // attributes, which represents the function names.
148     llvm::SmallVector<mlir::StringAttr, 4> func_attr_keys;
149     mlir::ArrayAttr op_func_attrs =
150         corert_converter_.CreateOpFuncAttrs(op->getAttrs(), &func_attr_keys);
151 
152     // Remove the function attributes, which have already been processed.
153     for (const auto &key : func_attr_keys) op->removeAttr(key);
154 
155     mlir::ArrayAttr op_attrs = corert_converter_.CreateOpAttrs(op->getAttrs());
156     if (!op_attrs) return op->emitWarning("failed to lower attributes.");
157 
158     mlir::StringAttr op_name =
159         rewriter.getStringAttr(op->getName().getStringRef());
160 
161     // Ops with _tpu_replicate attribute are TPU ops.
162     bool is_tpu_op = op->hasAttr("_tpu_replicate") ||
163                      llvm::isa<mlir::TF::TPUReplicatedInputOp,
164                                mlir::TF::TPUReplicatedOutputOp>(op);
165 
166     if ((parsed_device_name->device_type == kDeviceTypeTpu &&
167          !tpu_lower_to_fallback_) ||
168         // Convert ops running on TPU to CoreRT dialect to prevent the creation
169         // of tfrt_fallback_async.createop for them.
170         // These ops will be encountered here only when using fallback to run
171         // TPU models, in which case, these ops are assumed to be in a function
172         // called by a TPUPartitionedCall op and will be compiled in
173         // TPUPartitionedCall op via FunctionLibraryRuntime and not be processed
174         // by BEFExecutor.
175         is_tpu_op) {
176       return ConvertToCoreRTExecuteOp(
177           op, operands, parsed_device_name->op_handler_name, op_attrs,
178           op_func_attrs, op_name, rewriter);
179     }
180 
181     // For DEVICE_CPU, DEVICE_DEFAULT, and DEVICE_TPU_SYSTEM, we use fallback.
182     // Note that the TPU bridge should handle all ops that are required to be
183     // executed on TPU. So if there are still ops that are placed on TPU at this
184     // stage of lowering TF to TFRT, then these ops are supposed to be executed
185     // on host.
186     return ConvertToFallbackExecuteOp(op, operands, device, op_attrs,
187                                       op_func_attrs, op_name,
188                                       fallback_converter_, rewriter);
189   }
190 
191  private:
192   // Return true if this op can be lowered to fallback ops.
193   bool UseFallback(mlir::Operation *op) const {
194     if (!llvm::isa<mlir::TF::TensorFlowDialect>(op->getDialect())) return false;
195 
196     // Below is a blocklist of ops that should not go through CoreRTExecuteOp
197     // conversion.
198     // TODO(b/173017701): have a centralized place to hold the information
199     // whether a TF op should be lowered to FallbackExecute op.
200     return !llvm::isa<mlir::TF::_TfrtSetResourceOp,
201                       mlir::TF::_TfrtGetResourceOp,
202                       // Do not use fallback on the TPU fused
203                       // compile_and_execute kernel.
204                       mlir::TF::TPUCompileMlirAndExecuteOp,
205                       // Specifically handle control flow ops.
206                       mlir::TF::CaseOp, mlir::TF::IfOp, mlir::TF::WhileOp,
207                       mlir::TF::StatefulPartitionedCallOp,
208                       mlir::TF::PartitionedCallOp, mlir::TF::LegacyCallOp>(op);
209   }
210 
211   bool IsTpuCompileAndExecuteOps(mlir::Operation *op) const {
212     return llvm::isa<mlir::TF::_TPUCompileMlirOp,
213                      mlir::TF::TPUCompileSucceededAssertOp,
214                      mlir::TF::TPUExecuteOp>(op);
215   }
216 
217   mlir::LogicalResult ConvertToFallbackExecuteOp(
218       mlir::Operation *op, ValueRange operands, mlir::StringAttr device,
219       mlir::ArrayAttr op_attrs, mlir::ArrayAttr op_func_attrs,
220       mlir::StringAttr op_name,
221       tfrt_compiler::FallbackConverter &fallback_converter,
222       mlir::ConversionPatternRewriter &rewriter) const;
223 
224   mlir::LogicalResult ConvertToCoreRTExecuteOp(
225       mlir::Operation *op, ValueRange operands, llvm::StringRef op_handler_name,
226       mlir::ArrayAttr op_attrs, mlir::ArrayAttr op_func_attrs,
227       mlir::StringAttr op_name,
228       mlir::ConversionPatternRewriter &rewriter) const;
229 
230   CoreRTConverter &corert_converter_;
231   tfrt_compiler::FallbackConverter &fallback_converter_;
232   const tfrt_compiler::CostAnalysis &cost_analysis_;
233   bool tpu_lower_to_fallback_;
234   bool target_tpurt_;
235 };
236 
237 mlir::LogicalResult FallbackExecuteOpConversion::ConvertToFallbackExecuteOp(
238     mlir::Operation *op, ValueRange operands, mlir::StringAttr device,
239     mlir::ArrayAttr op_attrs, mlir::ArrayAttr op_func_attrs,
240     mlir::StringAttr op_name,
241     tfrt_compiler::FallbackConverter &fallback_converter,
242     mlir::ConversionPatternRewriter &rewriter) const {
243   llvm::SmallVector<Type, 4> result_types(
244       op->getNumResults(), rewriter.getType<tfrt::fallback::TFTensorType>());
245 
246   // Convert the operands to tfrt_fallback.tf_tensor if needed.
247   llvm::SmallVector<mlir::Value, 4> new_operands;
248   if (mlir::failed(tfrt_compiler::ConvertFallbackOperands(
249           op, device.getValue(), operands, &new_operands, rewriter)))
250     return failure();
251 
252   auto fallback_key =
253       rewriter.getI64IntegerAttr(fallback_converter.GetNextFallbackKey());
254 
255   // Query cost analysis to assign costs.
256   IntegerAttr cost;
257   auto parsed_device_name =
258       corert_converter_.ParseDeviceName(device.getValue());
259   if (parsed_device_name && parsed_device_name->device_type == DEVICE_GPU) {
260     // For GPU ops, the host only needs to dispatch them to GPUs, which should
261     // be relatively cheap for the host.
262     cost = rewriter.getI64IntegerAttr(kDefaultCheapCost);
263   } else {
264     cost = rewriter.getI64IntegerAttr(
265         cost_analysis_.GetCost(op, fallback_key.getInt()));
266   }
267 
268   if (mlir::MemoryEffectOpInterface::hasNoEffect(op)) {
269     auto new_op = rewriter.create<tfrt::fallback_async::ExecuteOp>(
270         op->getLoc(), result_types, new_operands, device, op_attrs,
271         op_func_attrs, fallback_key, op_name, cost);
272     fallback_converter.RegisterFallbackOp(new_op);
273     rewriter.replaceOp(op, new_op.results());
274   } else {
275     auto in_chain = corert_converter_.GetLocalSideEffectChain(op, &rewriter);
276     auto out_chain = in_chain;
277 
278     if (tfrt_compiler::IsTensorArrayOp(op)) {
279       // If it is a tensor array op, we don't need to use
280       // tfrt_fallback_async.executeop.seq because its operands/results already
281       // take care of control dependencies.
282       auto new_op = rewriter.create<tfrt::fallback_async::ExecuteOp>(
283           op->getLoc(), result_types, new_operands, device, op_attrs,
284           op_func_attrs, fallback_key, op_name, cost);
285       fallback_converter.RegisterFallbackOp(new_op);
286       rewriter.replaceOp(op, new_op.results());
287     } else {
288       // Create tfrt_fallback.executeop.seq if it is a side-effecting op.
289       auto new_op = rewriter.create<tfrt::fallback_async::ExecuteOpSeq>(
290           op->getLoc(), corert_converter_.chain_type(), result_types, in_chain,
291           new_operands, device, op_attrs, op_func_attrs, fallback_key, op_name,
292           cost);
293       fallback_converter.RegisterFallbackOp(new_op);
294       rewriter.replaceOp(op, new_op.results());
295       out_chain = new_op.out_op_chain();
296     }
297 
298     // Register the converted op so that it can be retrieved by successors.
299     corert_converter_.RegisterLocalSideEffectChain(op, out_chain);
300   }
301 
302   return success();
303 }
304 
305 mlir::LogicalResult FallbackExecuteOpConversion::ConvertToCoreRTExecuteOp(
306     mlir::Operation *op, ValueRange operands, llvm::StringRef op_handler_name,
307     mlir::ArrayAttr op_attrs, mlir::ArrayAttr op_func_attrs,
308     mlir::StringAttr op_name, mlir::ConversionPatternRewriter &rewriter) const {
309   llvm::SmallVector<Type, 4> result_types(
310       op->getNumResults(), rewriter.getType<tfrt::corert::TensorHandleType>());
311 
312   // Convert the operands to tensorhandles if needed.
313   llvm::SmallVector<mlir::Value, 4> new_operands;
314   if (mlir::failed(tfrt_compiler::ConvertCoreRTOperands(
315           op, operands, &new_operands, rewriter)))
316     return failure();
317 
318   // Get the op handler, or create one if one does not exist. Note that
319   // ConvertOpHandler changes internal state so it can only be called if the
320   // rewrite is guaranteed to succeed afterwards.
321   auto op_handler =
322       corert_converter_.ConvertOpHandler(op, op_handler_name, &rewriter);
323   if (!op_handler) return failure();
324 
325   if (mlir::MemoryEffectOpInterface::hasNoEffect(op)) {
326     auto new_op = rewriter.create<tfrt::corert::ExecuteOp>(
327         op->getLoc(), result_types, op_handler, new_operands, op_attrs,
328         op_func_attrs, op_name);
329     rewriter.replaceOp(op, new_op.results());
330   } else {
331     // Create corert.executeop.seq if it is a side-effecting op.
332     auto new_op = rewriter.create<tfrt::corert::ExecuteOpSeq>(
333         op->getLoc(), corert_converter_.chain_type(), result_types, op_handler,
334         corert_converter_.GetLocalSideEffectChain(op, &rewriter), new_operands,
335         op_attrs, op_func_attrs, op_name);
336     rewriter.replaceOp(op, new_op.results());
337 
338     // Register the converted op so that it can be retrieved by successors.
339     corert_converter_.RegisterLocalSideEffectChain(op, new_op.out_op_chain());
340   }
341 
342   return success();
343 }
344 
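// Converts a tf.Const op whose dtype has no fast-path handling below (e.g.
// quantized types) to a tfrt::fallback_async::ConstTensorProtoOp carrying the
// serialized TensorProto.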
345 class FallbackConstOpConversion
346     : public mlir::OpConversionPattern<mlir::TF::ConstOp> {
347  public:
348   FallbackConstOpConversion(mlir::MLIRContext *context,
349                             CoreRTConverter *corert_converter)
350       : mlir::OpConversionPattern<mlir::TF::ConstOp>(context),
351         corert_converter_(*corert_converter) {}
352 
353   mlir::LogicalResult matchAndRewrite(
354       mlir::TF::ConstOp op, OpAdaptor adaptor,
355       mlir::ConversionPatternRewriter &rewriter) const override {
356     // Some data types are handled separately using a fast path.
357     if (corert_converter_.IsSupportedNumericDType(op.dtype()) ||
358         op.dtype().isa<mlir::TF::StringType>())
359       return failure();
360 
361     // For other data types that do not have a fast path (e.g. quantized
362     // types), we convert them to a serialized tensor proto.
363 
364     tensorflow::TensorProto tensor_proto;
365     auto status = ConvertToTensorProto(op.value(), &tensor_proto);
366     if (!status.ok()) return op.emitError(status.error_message());
367 
368     rewriter.replaceOpWithNewOp<tfrt::fallback_async::ConstTensorProtoOp>(
369         op, rewriter.getType<tfrt::fallback::TFTensorType>(),
370         tensor_proto.SerializeAsString());
371 
372     return success();
373   }
374 
375  private:
376   CoreRTConverter &corert_converter_;
377 };
378 
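// Converts a tf._TfrtSetResource op to a tfrt::fallback_async::SetResourceOp,
// threading the local side-effect chain through the new op.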
379 class FallbackSetResourceOp
380     : public mlir::OpConversionPattern<mlir::TF::_TfrtSetResourceOp> {
381  public:
382   FallbackSetResourceOp(mlir::MLIRContext *context,
383                         CoreRTConverter *corert_converter)
384       : mlir::OpConversionPattern<mlir::TF::_TfrtSetResourceOp>(context),
385         corert_converter_(*corert_converter) {}
386 
387   mlir::LogicalResult matchAndRewrite(
388       mlir::TF::_TfrtSetResourceOp op, OpAdaptor adaptor,
389       mlir::ConversionPatternRewriter &rewriter) const override {
390     mlir::StringAttr device = op->getAttrOfType<mlir::StringAttr>(kDeviceAttr);
391     if (!device || device.getValue().empty())
392       return op->emitWarning("failed to find a non-empty 'device' attribute");
393 
394     // Currently the static resource is always on host CPU.
395     //
396     // TODO(chky): Support resource on other devices.
397     llvm::SmallVector<mlir::Value, 4> new_operands;
398     if (mlir::failed(tfrt_compiler::ConvertFallbackOperands(
399             op, tfrt_compiler::GetDefaultCpuDeviceName(), adaptor.getOperands(),
400             &new_operands, rewriter)))
401       return failure();
402 
403     assert(new_operands.size() == 1);
404 
405     auto new_op = rewriter.create<tfrt::fallback_async::SetResourceOp>(
406         op.getLoc(), corert_converter_.chain_type(),
407         corert_converter_.GetLocalSideEffectChain(op, &rewriter),
408         new_operands[0], device.getValue(), op.index());
409 
410     // Register the converted op so that it can be retrieved by successors.
411     corert_converter_.RegisterLocalSideEffectChain(op, new_op.out_ch());
412 
413     rewriter.eraseOp(op);
414 
415     return success();
416   }
417 
418  private:
419   CoreRTConverter &corert_converter_;
420 };
421 
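// Converts a tf._TfrtGetResource op to a tfrt::fallback_async::GetResourceOp
// that produces fallback tensors on the op's device.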
422 class FallbackGetResourceOp
423     : public mlir::OpConversionPattern<mlir::TF::_TfrtGetResourceOp> {
424  public:
425   FallbackGetResourceOp(mlir::MLIRContext *context,
426                         CoreRTConverter *corert_converter)
427       : mlir::OpConversionPattern<mlir::TF::_TfrtGetResourceOp>(context),
428         corert_converter_(*corert_converter) {}
429 
430   mlir::LogicalResult matchAndRewrite(
431       mlir::TF::_TfrtGetResourceOp op, OpAdaptor adaptor,
432       mlir::ConversionPatternRewriter &rewriter) const override {
433     mlir::StringAttr device = op->getAttrOfType<mlir::StringAttr>(kDeviceAttr);
434     if (!device || device.getValue().empty())
435       return op->emitWarning("failed to find a non-empty 'device' attribute");
436 
437     llvm::SmallVector<mlir::Type, 4> result_types(
438         op.getNumResults(), rewriter.getType<tfrt::fallback::TFTensorType>());
439 
440     auto ready_chain = rewriter.create<tfrt::compiler::NewChainOp>(
441         op.getLoc(), rewriter.getType<tfrt::compiler::ChainType>());
442 
443     auto new_op = rewriter.create<tfrt::fallback_async::GetResourceOp>(
444         op.getLoc(), corert_converter_.chain_type(), result_types, ready_chain,
445         device.getValue(), op.indices());
446 
447     rewriter.replaceOp(op, new_op.results());
448 
449     return success();
450   }
451 
452  private:
453   CoreRTConverter &corert_converter_;
454 };
455 
456 // Convert a tf_device.remote_run op to a tfrt_dist.remote_execute_func op.
457 //
458 // For example,
459 //
460 // %result = tf_device.remote_run "/job:worker/replica:0/task:0"
461 //   @remote_func(%arg0) : (tensor<i32>) -> (tensor<i32>)
462 //
463 // is converted to
464 //
465 // %0 = tfrt_dist.get_distributed_context
466 // %1 = tfrt_dist.get_task_handle %0
467 //  {task_name = "/job:worker/replica:0/task:0"}
468 // %2 = tfrt_dist.test_create_remote_chain_manager %0
469 // %3 = tfrt_dist.get_chain_for_task_handle %in_chain, %2, %1
470 // %out_op_chain, %results:2 = tfrt_dist.remote_execute_func[%in_chain, %0, %1]
471 //  @remote_func(%3, %arg1): (!tfrt_dist.remote_object_id, !corert.tensorhandle)
472 //  -> (!tfrt_dist.remote_object_id, !corert.tensorhandle)
473 // %4 = tfrt_dist.set_chain_for_task_handle %out_op_chain, %2, %1, %results#0
474 //
475 class TFDeviceRemoteRunOpConversion
476     : public mlir::OpConversionPattern<tf_device::RemoteRunOp> {
477  public:
478   TFDeviceRemoteRunOpConversion(mlir::MLIRContext *context,
479                                 mlir::TypeConverter *type_converter,
480                                 CoreRTConverter *corert_converter)
481       : mlir::OpConversionPattern<tf_device::RemoteRunOp>(context,
482                                                           kCoreRTBenefit),
483         type_converter_(*type_converter),
484         corert_converter_(*corert_converter) {}
485 
486   LogicalResult matchAndRewrite(
487       tf_device::RemoteRunOp op, OpAdaptor adaptor,
488       ConversionPatternRewriter &rewriter) const override {
489     mlir::Value distributed_context =
490         corert_converter_.GetDistributedContext(op.getOperation(), &rewriter);
491     mlir::Value in_op_chain =
492         corert_converter_.GetLocalSideEffectChain(op, &rewriter);
493     mlir::Value task_handle = corert_converter_.GetTaskHandle(
494         op.getOperation(), op.host(), &rewriter);
495     mlir::Value remote_chain_mgr =
496         corert_converter_.GetRemoteChainManager(op, &rewriter);
497     mlir::Type remote_obj_id_ty =
498         rewriter.getType<tfrt::dist::RemoteObjectIdType>();
499     ModuleOp module = op->getParentOfType<ModuleOp>();
500     SymbolTable symtab(module);
501     func::FuncOp callee = symtab.lookup<func::FuncOp>(op.callee());
502     if (!callee) {
503       op.emitOpError("callee function ") << op.callee() << " is not found";
504       return failure();
505     }
506     StringAttr host = callee->getAttrOfType<StringAttr>(kHostAttr);
507     if (!host) {
508       op.emitOpError("callee function ")
509           << op.callee() << " should have the host attribute";
510       return failure();
511     }
512 
513     llvm::SmallVector<mlir::Value, 4> arguments;
514     // The first argument of the remote function should be a remote chain which
515     // is added to the function signature when it is lowered from TF dialect to
516     // TFRT dialect.
517     arguments.push_back(corert_converter_.GetRemoteSideEffectChain(
518         op, host.getValue(), &rewriter));
519     for (mlir::Value argument : op.callee_args()) {
520       arguments.push_back(argument);
521     }
522 
523     llvm::SmallVector<mlir::Type, 4> result_types;
524     // The first result of the remote function should be a remote chain which
525     // is added to the function signature when it is lowered from TF dialect to
526     // TFRT dialect.
527     result_types.push_back(remote_obj_id_ty);
528     for (mlir::Type type : op.getResultTypes()) {
529       (void)type_converter_.convertType(type, result_types);
530     }
531     auto remote_execute_func_op =
532         rewriter.create<tfrt::dist::RemoteExecuteFuncOp>(
533             op.getLoc(), corert_converter_.chain_type(), result_types,
534             in_op_chain, distributed_context, task_handle, op.callee(),
535             arguments);
536     rewriter.replaceOp(op, remote_execute_func_op.results().drop_front(1));
537 
538     auto set_chain_op = rewriter.create<tfrt::dist::SetChainForTaskHandleOp>(
539         op.getLoc(), corert_converter_.chain_type(),
540         remote_execute_func_op.out_op_chain(), remote_chain_mgr, task_handle,
541         remote_execute_func_op.results().front());
542     corert_converter_.RegisterLocalSideEffectChain(op,
543                                                    set_chain_op.out_op_chain());
544 
545     return success();
546   }
547 
548  private:
549   mlir::TypeConverter &type_converter_;
550   CoreRTConverter &corert_converter_;
551 };
552 
553 // Lowers a tf.BatchFunction to tfrt_fallback.batch_function.
554 class FallbackBatchFunctionOpConversion
555     : public mlir::OpConversionPattern<mlir::TF::BatchFunctionOp> {
556  public:
557   FallbackBatchFunctionOpConversion(mlir::MLIRContext *context,
558                                     CoreRTConverter *corert_converter)
559       : mlir::OpConversionPattern<mlir::TF::BatchFunctionOp>(context,
560                                                              kFallbackBenefit),
561         corert_converter_(*corert_converter) {}
562 
563   LogicalResult matchAndRewrite(
564       mlir::TF::BatchFunctionOp op, OpAdaptor adaptor,
565       ConversionPatternRewriter &rewriter) const override {
566     corert_converter_.MaterializeDerivedAttributes(op);
567 
568     // Remove the device attribute for fallback, as fallback currently selects
569     // the device automatically.
570     //
571     // TODO(chky): The device attribute should be passed explicitly. This can be
572     // done once we change the kernel implementation to choose the device based
573     // on attributes.
574     op->removeAttr(rewriter.getStringAttr(kDeviceAttr));
575 
576     SymbolRefAttr f = op.fAttr();
577 
578     llvm::SmallVector<NamedAttribute, 12> attr_array;
579     for (auto &key_and_value : op->getAttrs()) {
580       StringRef name = key_and_value.getName();
581       if (name == "_output_shapes" || name == "f") {
582         continue;
583       }
584       attr_array.push_back(key_and_value);
585     }
586     ArrayAttr op_attrs = corert_converter_.CreateOpAttrs(attr_array);
587     if (!op_attrs) return op.emitWarning("failed to lower attributes.");
588 
589     llvm::SmallVector<Type, 4> result_types;
590     for (auto type : op.getResultTypes()) {
591       if (failed(corert_converter_.convertType(type, result_types)))
592         return failure();
593     }
594 
595     llvm::SmallVector<mlir::Value, 4> new_operands;
596     if (mlir::failed(tfrt_compiler::ConvertCoreRTOperands(
597             op, adaptor.getOperands(), &new_operands, rewriter)))
598       return failure();
599 
600     auto new_op = rewriter.create<tfrt::fallback_async::BatchFunctionOp>(
601         op.getLoc(), corert_converter_.chain_type(), result_types,
602         corert_converter_.GetLocalSideEffectChain(op, &rewriter), new_operands,
603         f, op_attrs);
604     rewriter.replaceOp(op, new_op.results());
605 
606     // Register the converted op so that it can be retrieved by successors.
607     corert_converter_.RegisterLocalSideEffectChain(op, new_op.out_op_chain());
608 
609     return success();
610   }
611 
612  private:
613   CoreRTConverter &corert_converter_;
614 };
615 
616 // Lower a tf.Const op that creates a dense numeric tensor to a native
617 // corert.const_dense_tensor op.
618 class CoreRTConstDenseTensorOpConversion
619     : public mlir::OpConversionPattern<mlir::TF::ConstOp> {
620  public:
621   CoreRTConstDenseTensorOpConversion(mlir::MLIRContext *context,
622                                      CoreRTConverter *corert_converter)
623       : mlir::OpConversionPattern<mlir::TF::ConstOp>(context, kCoreRTBenefit),
624         corert_converter_(*corert_converter) {}
625 
626   LogicalResult matchAndRewrite(
627       mlir::TF::ConstOp op, OpAdaptor adaptor,
628       ConversionPatternRewriter &rewriter) const override {
629     if (!corert_converter_.IsSupportedNumericDType(op.dtype()))
630       return failure();
631 
632     // Only CPU ops can be lowered using this conversion. If there is no device
633     // assignment, this op is treated as a CPU op and can be lowered.
634     if (auto parsed_device_name = corert_converter_.ParseDeviceName(op))
635       if (parsed_device_name->device_type != DEVICE_CPU) return failure();
636 
637     auto new_op = rewriter.create<tfrt::corert::ConstDenseTensorOp>(
638         op.getLoc(), corert_converter_.tensor_handle_type(),
639         op.value().cast<DenseElementsAttr>());
640     rewriter.replaceOp(op, new_op->getResult(0));
641     return success();
642   }
643 
644  private:
645   CoreRTConverter &corert_converter_;
646 };
647 
648 // Convert the FuncOp with the following changes to meet TFRT's requirements:
649 // 1) Convert types for the arguments and the results.
650 // 2) Add a chain to the arguments.
651 // 3) Add a chain to the results for side-effects.
652 // 4) If any argument has a tf.device attribute, change the attribute name
653 //    to tfrt.device.
654 // 5) If any result has a tf.device attribute, change the attribute name
655 //    to tfrt.device.
656 //
657 // The input chain is used to signal visibility of all side-effects before
658 // calling this function. The output chain is used to signal visibility of all
659 // side-effects of this function.
660 class TFRTFuncOpSignatureConversion
661     : public mlir::OpConversionPattern<mlir::func::FuncOp> {
662  public:
663   TFRTFuncOpSignatureConversion(mlir::MLIRContext *ctx,
664                                 mlir::TypeConverter *type_converter)
665       : OpConversionPattern(ctx), type_converter_(*type_converter) {}
666 
667   LogicalResult matchAndRewrite(
668       mlir::func::FuncOp func_op, OpAdaptor adaptor,
669       ConversionPatternRewriter &rewriter) const override {
670     mlir::FunctionType type = func_op.getFunctionType();
671 
672     // Convert the original function arguments.
673     mlir::TypeConverter::SignatureConversion converted_signature(
674         type.getNumInputs());
675     // Add a chain as the first input.
676     converted_signature.addInputs(
677         rewriter.getType<tfrt::compiler::ChainType>());
678 
679     // Convert the original function results.
680     SmallVector<Type, 1> converted_results;
681     // Add a chain as the first result.
682     converted_results.push_back(rewriter.getType<tfrt::compiler::ChainType>());
683 
684     if (failed(type_converter_.convertSignatureArgs(type.getInputs(),
685                                                     converted_signature)) ||
686         failed(type_converter_.convertTypes(type.getResults(),
687                                             converted_results)) ||
688         failed(rewriter.convertRegionTypes(&func_op.getBody(), type_converter_,
689                                            &converted_signature))) {
690       return failure();
691     }
692 
693     llvm::SmallVector<mlir::DictionaryAttr, 4> arg_attrs;
694     // The first input, which is a chain added by this pass, has no attribute.
695     arg_attrs.emplace_back();
696     func_op.getAllArgAttrs(arg_attrs);
697     // If any argument has a tf.device attribute, change the attribute name to
698     // tfrt.device.
699     for (auto &arg_attr : arg_attrs) {
700       mlir::NamedAttrList arg_attr_list(arg_attr);
701       if (Attribute device = arg_attr_list.erase(kTFDeviceAttr)) {
702         arg_attr_list.set(kTFRTDeviceAttr, device);
703         arg_attr = arg_attr_list.getDictionary(device.getContext());
704       }
705     }
706     arg_attrs.resize(converted_signature.getConvertedTypes().size());
707 
708     // The first result, which is a chain added by this pass, has no attribute.
709     llvm::SmallVector<mlir::DictionaryAttr, 4> result_attrs;
710     result_attrs.emplace_back();
711     func_op.getAllResultAttrs(result_attrs);
712     // If any result has a tf.device attribute, change the attribute name to
713     // tfrt.device.
714     for (auto &result_attr : result_attrs) {
715       mlir::NamedAttrList result_attr_list(result_attr);
716       if (Attribute device = result_attr_list.erase(kTFDeviceAttr)) {
717         result_attr_list.set(kTFRTDeviceAttr, device);
718         result_attr = result_attr_list.getDictionary(device.getContext());
719       }
720     }
721     result_attrs.resize(converted_results.size());
722 
723     // Update the function signature in-place.
724     rewriter.updateRootInPlace(func_op, [&] {
725       func_op.setType(mlir::FunctionType::get(
726           func_op.getContext(), converted_signature.getConvertedTypes(),
727           converted_results));
728       func_op.setAllArgAttrs(arg_attrs);
729       func_op.setAllResultAttrs(result_attrs);
730     });
731 
732     return success();
733   }
734 
735  private:
736   mlir::TypeConverter &type_converter_;
737 };
738 
739 // Lower a tf.Const op that creates a string tensor to a native
740 // corert.const_string_tensor op.
741 class CoreRTConstStringTensorOpConversion
742     : public mlir::OpConversionPattern<mlir::TF::ConstOp> {
743  public:
744   CoreRTConstStringTensorOpConversion(mlir::MLIRContext *context,
745                                       CoreRTConverter *corert_converter)
746       : mlir::OpConversionPattern<mlir::TF::ConstOp>(context, kCoreRTBenefit),
747         corert_converter_(*corert_converter) {}
748 
749   LogicalResult matchAndRewrite(
750       mlir::TF::ConstOp op, OpAdaptor adaptor,
751       ConversionPatternRewriter &rewriter) const override {  // NOLINT
752     if (!op.dtype().isa<mlir::TF::StringType>()) return failure();
753 
754     DenseStringElementsAttr attr = op.value().cast<DenseStringElementsAttr>();
755 
756     llvm::SmallVector<Attribute, 4> values;
757     values.reserve(attr.getNumElements());
758     for (const auto &element : attr.getRawStringData())
759       values.push_back(rewriter.getStringAttr(
760           llvm::StringRef(element.data(), element.size())));
761 
762     // Create the shape attribute from the tensor shape.
763     ArrayRef<int64_t> shape = op.value().getType().getShape();
764     llvm::SmallVector<mlir::Attribute, 4> dims;
765     dims.reserve(shape.size());
766     auto i64_type = rewriter.getIntegerType(64);
767     for (auto dim : shape)
768       dims.push_back(rewriter.getIntegerAttr(i64_type, dim));
769 
770     auto new_op = rewriter.create<tfrt::corert::ConstStringTensorOp>(
771         op.getLoc(), corert_converter_.tensor_handle_type(),
772         rewriter.getArrayAttr(dims), rewriter.getArrayAttr(values));
773 
774     rewriter.replaceOp(op, new_op.result());
775 
776     return success();
777   }
778 
779  private:
780   CoreRTConverter &corert_converter_;
781 };
782 
783 // Convert TF dialect operations with no side effects to CoreRT ExecuteOp. For
784 // example,
785 //
786 // %0 = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} :
787 //    (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
788 //
789 // is converted to
790 //
791 // %result = corert.executeop(%device)
792 //    "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} :
793 //    (!corert.tensorhandle, !corert.tensorhandle) -> !corert.tensorhandle
794 //
795 // Note that it will fail to match if some attributes are not supported.
796 template <typename TF_Op>
797 class CoreRTExecuteOpConversion : public mlir::OpConversionPattern<TF_Op> {
798  public:
799   CoreRTExecuteOpConversion(mlir::MLIRContext *context,
800                             CoreRTConverter *corert_converter)
801       : CoreRTExecuteOpConversion(context, corert_converter, "") {}
802 
803   // If device_name is not empty, only ops that use this device are lowered
804   // using CoreRTExecuteOpConversion.
805   CoreRTExecuteOpConversion(mlir::MLIRContext *context,
806                             CoreRTConverter *corert_converter,
807                             llvm::StringRef device_name)
808       : mlir::OpConversionPattern<TF_Op>(context, kCoreRTBenefit),
809         corert_converter_(*corert_converter),
810         device_name_(device_name) {}
811 
812   LogicalResult matchAndRewrite(
813       TF_Op op, typename TF_Op::Adaptor adaptor,
814       ConversionPatternRewriter &rewriter) const override {
815     auto parsed_device_name = corert_converter_.ParseDeviceName(op);
816     // Return failure and emit warning if there is no device assignment.
817     if (!parsed_device_name) {
818       return op->emitWarning(
819           "failed to retrieve valid device when converting to "
820           "corert.executeop");
821     }
822 
823     // If device_name is specified, check the device of this op first.
824     if (!device_name_.empty()) {
825       // Skip if it does not match the specified device.
826       if (parsed_device_name->device_name != device_name_) return failure();
827     }
828 
829     mlir::StringAttr op_name = rewriter.getStringAttr(op.getOperationName());
830 
831     llvm::SmallVector<Type, 4> result_types;
832     for (auto type : op.getOperation()->getResultTypes()) {
833       if (failed(corert_converter_.convertType(type, result_types)))
834         return failure();
835     }
836 
837     corert_converter_.MaterializeDerivedAttributes(op);
838 
839     // Convert the function (symbol) attributes to an array of string
840     // attributes, which represents the function names.
841     llvm::SmallVector<mlir::StringAttr, 4> func_attr_keys;
842     ArrayAttr op_func_attrs =
843         corert_converter_.CreateOpFuncAttrs(op->getAttrs(), &func_attr_keys);
844 
845     // Remove the function attributes, which have already been processed.
846     for (const auto &key : func_attr_keys) op->removeAttr(key);
847 
848     ArrayAttr op_attrs = corert_converter_.CreateOpAttrs(op->getAttrs());
849     if (!op_attrs) return op.emitError("failed to lower attributes.");
850 
851     llvm::SmallVector<mlir::Value, 4> new_operands;
852     if (mlir::failed(tfrt_compiler::ConvertCoreRTOperands(
853             op, adaptor.getOperands(), &new_operands, rewriter)))
854       return failure();
855 
856     // Get the op handler, or create one if one does not exist. Note that
857     // ConvertOpHandler changes internal state so it can only be called if the
858     // rewrite is guaranteed to succeed afterwards.
859     auto op_handler = corert_converter_.ConvertOpHandler(
860         op, parsed_device_name->op_handler_name, &rewriter);
861     if (!op_handler) return failure();
862 
863     auto new_op = rewriter.create<tfrt::corert::ExecuteOp>(
864         op.getLoc(), result_types, op_handler, new_operands, op_attrs,
865         op_func_attrs, op_name);
866 
867     rewriter.replaceOp(op, new_op.results());
868     return success();
869   }
870 
871  private:
872   CoreRTConverter &corert_converter_;
873   llvm::StringRef device_name_;
874 };
875 
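// Converts the operands of a function call (or of an op that invokes
// functions, e.g. tf.Case) to either fallback tensors or CoreRT tensor
// handles, matching the tensor representation used by the lowered callee.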
876 LogicalResult ConvertFunctionCallOperands(
877     mlir::Operation *op, ValueRange operands,
878     llvm::SmallVectorImpl<mlir::Value> *new_operands,
879     mlir::ConversionPatternRewriter &rewriter, bool func_use_fallback_tensor) {
880   if (func_use_fallback_tensor) {
881     // TODO(b/182232457): Support other devices.
882     return tfrt_compiler::ConvertFallbackOperands(
883         op, tfrt_compiler::GetDefaultCpuDeviceName(), operands, new_operands,
884         rewriter);
885   } else {
886     return tfrt_compiler::ConvertCoreRTOperands(op, operands, new_operands,
887                                                 rewriter);
888   }
889 }
890 
891 // Convert TF call ops (eg. StatefulPartitionedCall) to tfrt.call.
892 template <typename CallOp>
893 class TFRTCallOpConversion : public mlir::OpConversionPattern<CallOp> {
894  public:
895   TFRTCallOpConversion(mlir::MLIRContext *context,
896                        mlir::TypeConverter *type_converter,
897                        CoreRTConverter *corert_converter,
898                        bool func_use_fallback_tensor)
899       : mlir::OpConversionPattern<CallOp>(context),
900         type_converter_(*type_converter),
901         corert_converter_(*corert_converter),
902         func_use_fallback_tensor_(func_use_fallback_tensor) {}
903 
904   LogicalResult matchAndRewrite(
905       CallOp op, typename CallOp::Adaptor adaptor,
906       ConversionPatternRewriter &rewriter) const override {
907     auto callee =
908         op.getCallableForCallee().template dyn_cast<mlir::SymbolRefAttr>();
909     if (!callee) return failure();
910 
911     llvm::SmallVector<mlir::Value, 4> new_operands;
912     new_operands.push_back(
913         corert_converter_.GetLocalSideEffectChain(op, &rewriter));
914     // Currently the converted functions always use !corert.tensorhandle types
915     // for tensor arguments and results.
916     //
917     // TODO(b/175881042): We should avoid the tensor conversion here if the
918     // operand is !tfrt_fallback.tf_tensor, and it is also used as fallback
919     // tensor inside the callee function.
920     if (mlir::failed(ConvertFunctionCallOperands(op, adaptor.getOperands(),
921                                                  &new_operands, rewriter,
922                                                  func_use_fallback_tensor_)))
923       return failure();
924 
925     llvm::SmallVector<mlir::Type, 4> result_types;
926     result_types.push_back(rewriter.getType<tfrt::compiler::ChainType>());
927     for (auto type : op.getOperation()->getResultTypes()) {
928       if (failed(type_converter_.convertType(type, result_types)))
929         return failure();
930     }
931 
932     auto new_op = rewriter.create<tfrt::compiler::CallOp>(
933         op.getLoc(), result_types, callee.getRootReference().getValue(),
934         new_operands);
935     rewriter.replaceOp(op, new_op.getResults().drop_front());
936 
937     if (!mlir::MemoryEffectOpInterface::hasNoEffect(op)) {
938       // Register the converted op so that it can be retrieved by successors.
939       // TODO(chky): Add OpTraits or OpInterface, rather than assume first
940       // result is a chain.
941       corert_converter_.RegisterLocalSideEffectChain(op, new_op.getResult(0));
942     }
943 
944     return success();
945   }
946 
947  private:
948   mlir::TypeConverter &type_converter_;
949   CoreRTConverter &corert_converter_;
950   bool func_use_fallback_tensor_;
951 };
952 
953 // Convert func ReturnOp to tfrt.return.
954 //
955 // TODO(chky): conversion to tfrt kernels should come from a common tf_to_tfrt
956 // library.
957 class TFRTReturnOpConversion
958     : public mlir::OpConversionPattern<mlir::func::ReturnOp> {
959  public:
960   TFRTReturnOpConversion(mlir::MLIRContext *context,
961                          CoreRTConverter *corert_converter,
962                          bool func_use_fallback_tensor)
963       : mlir::OpConversionPattern<mlir::func::ReturnOp>(context),
964         corert_converter_(*corert_converter),
965         func_use_fallback_tensor_(func_use_fallback_tensor) {}
966 
967   LogicalResult matchAndRewrite(
968       mlir::func::ReturnOp op, OpAdaptor adaptor,
969       ConversionPatternRewriter &rewriter) const override {
970     llvm::SmallVector<mlir::Value, 2> new_operands;
971 
972     // Currently in mlir::TF::SideEffectAnalysis, all terminator ops are treated
973     // as side-effect ops and they have predecessors but not successors.
974     //
975     // TODO(chky): ReturnOp has no side effect. Once the special handling in
976     // mlir::TF::SideEffectAnalysis is removed, the chains should come from
977     // side-effecting ops with no successors in the function.
978     new_operands.push_back(
979         corert_converter_.GetLocalSideEffectChain(op, &rewriter));
980     if (mlir::failed(ConvertFunctionCallOperands(op, adaptor.getOperands(),
981                                                  &new_operands, rewriter,
982                                                  func_use_fallback_tensor_)))
983       return failure();
984 
985     rewriter.replaceOpWithNewOp<tfrt::compiler::ReturnOp>(op, new_operands);
986     return success();
987   }
988 
989  private:
990   CoreRTConverter &corert_converter_;
991   bool func_use_fallback_tensor_;
992 };
993 
994 // Convert tf.Case op to tfrt.Case.
995 //
996 // TF dialect:
997 // %outputs = "tf.Case"(%arg, ...) { branches = [@branch0, @branch1], ...}
998 //
999 // lowered TFRT CoreRT dialect:
1000 // %idx_int = corert.tensorhandle_to_int32 %idx
1001 // %out_chain, %outputs = tfrt.case %idx_int [@branch0, @branch1] (%chain, ...)
1002 class TFRTCaseOpConversion : public mlir::OpConversionPattern<TF::CaseOp> {
1003  public:
1004   TFRTCaseOpConversion(mlir::MLIRContext *context,
1005                        mlir::TypeConverter *type_converter,
1006                        CoreRTConverter *corert_converter,
1007                        bool func_use_fallback_tensor)
1008       : mlir::OpConversionPattern<TF::CaseOp>(context),
1009         type_converter_(*type_converter),
1010         corert_converter_(*corert_converter),
1011         func_use_fallback_tensor_(func_use_fallback_tensor) {}
1012 
1013   LogicalResult matchAndRewrite(
1014       TF::CaseOp op, OpAdaptor adaptor,
1015       ConversionPatternRewriter &rewriter) const override {
1016     mlir::ArrayAttr branches = op.branches();
1017 
1018     llvm::SmallVector<mlir::Type, 4> result_types;
1019     result_types.push_back(corert_converter_.chain_type());
1020     for (mlir::Type type : op->getResultTypes()) {
1021       if (failed(type_converter_.convertType(type, result_types)))
1022         return failure();
1023     }
1024 
1025     llvm::SmallVector<mlir::Value, 4> branch_operands;
1026     branch_operands.push_back(
1027         corert_converter_.GetLocalSideEffectChain(op, &rewriter));
1028     if (mlir::failed(ConvertFunctionCallOperands(
1029             op, adaptor.getOperands().drop_front(), &branch_operands, rewriter,
1030             func_use_fallback_tensor_)))
1031       return failure();
1032 
1033     mlir::Value index_operand = adaptor.getOperands()[0];
1034     // TODO(b/182233401): Support TF tensor; remove the conversion op here.
1035     if (index_operand.getType().isa<tfrt::fallback::TFTensorType>()) {
1036       // TODO(b/182232457): Support other devices.
1037       index_operand =
1038           rewriter
1039               .create<
1040                   tfrt::fallback_async::FallbackTensorToCoreRTTensorHandleOp>(
1041                   op.getLoc(),
1042                   rewriter.getType<tfrt::corert::TensorHandleType>(),
1043                   adaptor.getOperands()[0],
1044                   tfrt_compiler::GetDefaultCpuDeviceName())
1045               .getResult(0);
1046     }
1047     if (!index_operand.getType().isa<tfrt::corert::TensorHandleType>())
1048       return op.emitError(
1049           "branch index operand is expected to be a TensorHandle.");
1050     mlir::Value index_value =
1051         rewriter.create<tfrt::corert::TensorHandleToIntOp>(
1052             op.getLoc(), rewriter.getI32Type(), index_operand);
1053 
1054     auto new_op = rewriter.create<tfrt::compiler::CaseOp>(
1055         op.getLoc(), result_types, index_value, branches, branch_operands);
1056 
1057     rewriter.replaceOp(op, new_op.branch_outputs().drop_front());
1058     return success();
1059   }
1060 
1061  private:
1062   mlir::TypeConverter &type_converter_;
1063   CoreRTConverter &corert_converter_;
1064   bool func_use_fallback_tensor_;
1065 };
1066 
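// Converts the given condition tensor to an i1 value by emitting a
// tfrt::fallback_async::PredicateOp, converting the operand to a fallback
// tensor first if needed. Returns a null value on failure.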
1067 static mlir::Value GetPredicate(mlir::Operation *op, mlir::Value cond_operand,
1068                                 mlir::ConversionPatternRewriter &rewriter) {
1069   if (!cond_operand.getType().isa<tfrt::fallback::TFTensorType>()) {
1070     cond_operand = tfrt_compiler::ConvertCoreRTTensorHandleToFallbackTensor(
1071         op->getLoc(), tfrt_compiler::GetDefaultCpuDeviceName(), cond_operand,
1072         rewriter);
1073     if (!cond_operand) {
1074       op->emitWarning("failed to convert the cond operand to fallback tensor.");
1075       return {};
1076     }
1077   }
1078 
1079   return rewriter.create<tfrt::fallback_async::PredicateOp>(
1080       op->getLoc(), rewriter.getI1Type(), cond_operand);
1081 }
1082 
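// Converts a tf.If op to a tfrt.cond kernel invocation, using GetPredicate
// above to turn the condition tensor into a boolean value.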
1083 class TFRTCondOpConversion : public mlir::OpConversionPattern<mlir::TF::IfOp> {
1084  public:
1085   TFRTCondOpConversion(mlir::MLIRContext *context,
1086                        mlir::TypeConverter *type_converter,
1087                        CoreRTConverter *corert_converter,
1088                        bool func_use_fallback_tensor)
1089       : mlir::OpConversionPattern<TF::IfOp>(context),
1090         type_converter_(*type_converter),
1091         corert_converter_(*corert_converter),
1092         func_use_fallback_tensor_(func_use_fallback_tensor) {}
1093 
1094   mlir::LogicalResult matchAndRewrite(
1095       mlir::TF::IfOp op, OpAdaptor adaptor,
1096       mlir::ConversionPatternRewriter &rewriter) const override {
1097     mlir::FlatSymbolRefAttr then_branch = op.then_branchAttr();
1098     mlir::FlatSymbolRefAttr else_branch = op.else_branchAttr();
1099 
1100     llvm::SmallVector<Type, 4> result_types;
1101     result_types.push_back(rewriter.getType<tfrt::compiler::ChainType>());
1102     for (mlir::Type type : op.getOperation()->getResultTypes()) {
1103       if (failed(type_converter_.convertType(type, result_types)))
1104         return failure();
1105     }
1106 
1107     // Convert the cond tensor to a boolean value so that it can be used by
1108     // tfrt.cond kernel.
1109     auto bool_cond = GetPredicate(op, adaptor.getOperands()[0], rewriter);
1110     if (!bool_cond) return failure();
1111 
1112     llvm::SmallVector<mlir::Value, 4> new_operands;
1113     // The first arg of converted branch functions should be !tfrt.chain.
1114     new_operands.push_back(
1115         corert_converter_.GetLocalSideEffectChain(op, &rewriter));
1116 
1117     if (mlir::failed(ConvertFunctionCallOperands(
1118             op, adaptor.getOperands().drop_front(), &new_operands, rewriter,
1119             func_use_fallback_tensor_)))
1120       return failure();
1121 
1122     auto new_op = rewriter.create<tfrt::compiler::CondOp>(
1123         op.getLoc(), result_types, bool_cond, then_branch, else_branch,
1124         new_operands);
1125 
1126     // The first result is a !tfrt.chain.
1127     rewriter.replaceOp(op, new_op.getResults().drop_front(1));
1128 
1129     if (!mlir::MemoryEffectOpInterface::hasNoEffect(op)) {
1130       // Register the converted op so that it can be retrieved by successors.
1131       // TODO(chky): Add OpTraits or OpInterface, rather than assume first
1132       // result is a chain.
1133       corert_converter_.RegisterLocalSideEffectChain(op, new_op.getResult(0));
1134     }
1135 
1136     return success();
1137   }
1138 
1139  private:
1140   mlir::TypeConverter &type_converter_;
1141   CoreRTConverter &corert_converter_;
1142   bool func_use_fallback_tensor_;
1143 };
1144 
1145 // Convert TF WhileOp to tfrt.while. tfrt.while uses a boolean condition and has
1146 // slightly different semantics from tf.While for performance and generality.
1147 // The pseudo code of tfrt.while is as follows:
1148 //
1149 //  while(cond) {
1150 //    outputs, cond = body(inputs)
1151 //    inputs = outputs
1152 //  }
1153 //  return outputs
1154 //
1155 // So we need to insert extra conversion kernels and merge functions when
1156 // lowering tf.While to tfrt.while.
1157 //
1158 //  %result = tf.While(%arg) {cond = @original_cond_fn, body =
1159 //  @original_body_fn}
1160 //
1161 // is converted to
1162 //
1163 //  func @new_pred_fn(%ch, %arg) {
1164 //    %ch0, %cond_tensor = tfrt.call @original_cond_fn(%ch, %arg)
1165 //    %cond_bool = tfrt_fallback_async.predicate %cond_tensor
1166 //    tfrt.return %ch0, %cond_bool
1167 //  }
1168 //
1169 //  func @new_while_body(%ch, %arg) {
1170 //    %ch0, %result = tfrt.call @original_body_fn(%ch, %arg)
1171 //    %ch1, %cond_bool = tfrt.call @new_pred_fn(%ch0, %result)
1172 //    tfrt.return %ch1, %result, %cond_bool
1173 //  }
1174 //
1175 //  %ch0, %first_iter_cond = tfrt.call @new_pred_fn(%ch, %arg)
1176 //  %ch1, %result = tfrt.while %first_iter_cond @new_while_body(%ch0, %arg)
1177 //
1178 class TFRTWhileOpConversion
1179     : public mlir::OpConversionPattern<mlir::TF::WhileOp> {
1180  public:
1181   TFRTWhileOpConversion(mlir::MLIRContext *context,
1182                         mlir::TypeConverter *type_converter,
1183                         CoreRTConverter *corert_converter,
1184                         mlir::SymbolTable *symbol_table,
1185                         const tfrt_compiler::TensorArraySideEffectAnalysis
1186                             *tensor_array_side_effect_analysis,
1187                         bool func_use_fallback_tensor,
1188                         bool enable_while_parallel_iterations)
1189       : mlir::OpConversionPattern<TF::WhileOp>(context),
1190         type_converter_(*type_converter),
1191         corert_converter_(*corert_converter),
1192         symbol_table_(*symbol_table),
1193         tensor_array_side_effect_analysis_(*tensor_array_side_effect_analysis),
1194         func_use_fallback_tensor_(func_use_fallback_tensor),
1195         enable_while_parallel_iterations_(enable_while_parallel_iterations) {}
1196 
1197   mlir::LogicalResult matchAndRewrite(
1198       mlir::TF::WhileOp op, OpAdaptor adaptor,
1199       mlir::ConversionPatternRewriter &rewriter) const override {
1200     mlir::FlatSymbolRefAttr cond_fn = op.condAttr();
1201     mlir::FlatSymbolRefAttr body_fn = op.bodyAttr();
1202 
1203     llvm::SmallVector<Type, 4> while_arg_result_types;
1204     // Insert a chain for side effects as the first argument/result.
1205     while_arg_result_types.push_back(
1206         rewriter.getType<tfrt::compiler::ChainType>());
1207     for (mlir::Type type : op.getOperation()->getResultTypes()) {
1208       if (failed(type_converter_.convertType(type, while_arg_result_types)))
1209         return failure();
1210     }
1211 
1212     // Convert the operands to either fallback tensors or corert tensors, as
1213     // specified in the option.
1214     llvm::SmallVector<mlir::Value, 4> new_operands;
1215     if (mlir::failed(ConvertFunctionCallOperands(op, adaptor.getOperands(),
1216                                                  &new_operands, rewriter,
1217                                                  func_use_fallback_tensor_)))
1218       return failure();
1219 
1220     // Create the predicate function that calls the original cond function and
1221     // in addition converts the result to a boolean value.
1222     mlir::func::FuncOp pred_fn =
1223         GetPredicateFunction(op, cond_fn, while_arg_result_types, rewriter);
1224     if (!pred_fn) return failure();
1225 
1226     auto in_chain = corert_converter_.GetLocalSideEffectChain(op, &rewriter);
1227     auto out_chain = in_chain;
1228 
1229     bool has_at_most_tensor_array_effect = HasAtMostTensorArrayEffect(op);
1230 
1231     // Prepare the arguments to call the pred function for the first iteration.
1232     llvm::SmallVector<mlir::Value, 4> pred_args;
1233     pred_args.push_back(
1234         has_at_most_tensor_array_effect ? GetFunctionInputChain(op) : in_chain);
1235     pred_args.append(new_operands.begin(), new_operands.end());
1236 
1237     // Insert a call op to call the pred function for the first iteration.
1238     auto call_pred_fn = rewriter.create<tfrt::compiler::CallOp>(
1239         op.getLoc(), pred_fn.getFunctionType().getResults(),
1240         pred_fn.getSymName(), pred_args);
1241 
1242     auto pred_chain = call_pred_fn.getResult(0);
1243     auto first_iteration_bool_cond = call_pred_fn.getResult(1);
1244 
1245     mlir::func::FuncOp new_body_fn = GetWhileBodyFunction(
1246         op, body_fn, pred_fn, while_arg_result_types, rewriter);
1247 
1248     // Use the pred chain as the chain to the while body. The rest of the args
1249     // should be the same as the pred_args.
1250     auto &while_args = pred_args;
1251     while_args[0] = pred_chain;
1252 
1253     int64_t parallel_iterations =
1254         enable_while_parallel_iterations_ ? op.parallel_iterations() : 1;
1255 
1256     auto new_op = rewriter.create<tfrt::compiler::WhileOp>(
1257         op.getLoc(), while_arg_result_types, first_iteration_bool_cond,
1258         while_args, new_body_fn.getSymName(), parallel_iterations);
1259 
1260     rewriter.replaceOp(op, new_op.getResults().drop_front());
1261 
1262     if (!has_at_most_tensor_array_effect) out_chain = new_op.getResult(0);
1263 
1264     if (!mlir::MemoryEffectOpInterface::hasNoEffect(op)) {
1265       // Register the converted op so that it can be retrieved by successors.
1266       // TODO(chky): Add OpTraits or OpInterface, rather than assume first
1267       // result is a chain.
1268       corert_converter_.RegisterLocalSideEffectChain(op, out_chain);
1269     }
1270     return success();
1271   }
1272 
1273  private:
1274   bool HasAtMostTensorArrayEffect(mlir::TF::WhileOp op) const {
1275     return tensor_array_side_effect_analysis_.HasAtMostTensorArrayEffect(
1276                op.cond_function()) &&
1277            tensor_array_side_effect_analysis_.HasAtMostTensorArrayEffect(
1278                op.body_function());
1279   }
1280 
1281   mlir::func::FuncOp GetPredicateFunction(
1282       mlir::TF::WhileOp op, mlir::FlatSymbolRefAttr cond_fn,
1283       mlir::TypeRange arg_types,
1284       mlir::ConversionPatternRewriter &rewriter) const;
1285 
1286   mlir::func::FuncOp GetWhileBodyFunction(
1287       mlir::TF::WhileOp op, mlir::FlatSymbolRefAttr body_fn,
1288       mlir::func::FuncOp pred_fn, mlir::TypeRange arg_types,
1289       mlir::ConversionPatternRewriter &rewriter) const;
1290 
1291   mlir::TypeConverter &type_converter_;
1292   CoreRTConverter &corert_converter_;
1293   mlir::SymbolTable &symbol_table_;
1294   const tfrt_compiler::TensorArraySideEffectAnalysis
1295       &tensor_array_side_effect_analysis_;
1296   bool func_use_fallback_tensor_;
1297   bool enable_while_parallel_iterations_;
1298 };
1299 
1300 // Create the pred function that contains a call to the original cond function
1301 // and a predicate kernel that converts the cond tensor to a boolean value, e.g.
1302 //
1303 // func @pred_fn(%ch, %arg) {
1304 //  %ch0, %cond_tensor = tfrt.call @original_cond_fn(%ch, %arg)
1305 //  %cond_bool = tfrt_fallback_async.predicate %cond_tensor
1306 //  return %ch0, %cond_bool
1307 // }
1308 //
1309 mlir::func::FuncOp TFRTWhileOpConversion::GetPredicateFunction(
1310     mlir::TF::WhileOp op, mlir::FlatSymbolRefAttr cond_fn,
1311     mlir::TypeRange arg_types,
1312     mlir::ConversionPatternRewriter &rewriter) const {
1313   std::string pred_fn_name = cond_fn.getValue().str() + "/tfrt_predicate";
1314 
1315   if (auto pred_fn = symbol_table_.lookup<mlir::func::FuncOp>(pred_fn_name)) {
1316     return pred_fn;
1317   }
1318 
1319   auto func_op = op->getParentOfType<mlir::func::FuncOp>();
1320 
1321   mlir::ConversionPatternRewriter::InsertionGuard insertion_guard(rewriter);
1322   rewriter.setInsertionPointAfter(func_op);
1323 
1324   std::array<mlir::Type, 2> pred_result_types = {
1325       rewriter.getType<tfrt::compiler::ChainType>(), rewriter.getI1Type()};
1326 
1327   auto func_type = rewriter.getFunctionType(arg_types, pred_result_types);
1328 
1329   auto pred_fn =
1330       rewriter.create<mlir::func::FuncOp>(op.getLoc(), pred_fn_name, func_type);
1331 
1332   auto *block = pred_fn.addEntryBlock();
1333   rewriter.setInsertionPointToStart(block);
1334 
1335   // There are at least two arguments, with the first being !tfrt.chain and the
1336   // second being either !tfrt_fallback.tf_tensor or !corert.tensorhandle.
1337   // cond_fn must have two results. The first of them must also be !tfrt.chain
1338   // and the second must have the same tensor type as arguments. So we can just
1339   // use the first two arg types as the result types.
1340   assert(arg_types.size() >= 2);
1341   auto cond_result_types = arg_types.take_front(2);
1342 
1343   auto call_cond_fn = rewriter.create<tfrt::compiler::CallOp>(
1344       op.getLoc(), cond_result_types, cond_fn, block->getArguments());
1345 
1346   auto chain = call_cond_fn.getResult(0);
1347   auto cond = call_cond_fn.getResult(1);
1348 
1349   auto bool_cond = GetPredicate(op, cond, rewriter);
1350   if (!bool_cond) return {};
1351 
1352   llvm::SmallVector<mlir::Value, 2> results = {chain, bool_cond};
1353 
1354   rewriter.create<tfrt::compiler::ReturnOp>(op.getLoc(), results);
1355 
1356   symbol_table_.insert(pred_fn);
1357 
1358   return pred_fn;
1359 }
1360 
1361 // Create the new while body function that contains a call to original while
1362 // body and then a call to the pred function, e.g.
1363 //
1364 // func @new_while_body(%ch, %arg) {
1365 //   %ch0, %result = tfrt.call @original_body(%ch, %arg)
1366 //   %ch1, %cond_bool = tfrt.call @pred_function(%ch0, %arg)
1367 //   tfrt.return %ch1, %result, %cond_bool
1368 // }
1369 //
1370 mlir::func::FuncOp TFRTWhileOpConversion::GetWhileBodyFunction(
1371     mlir::TF::WhileOp op, mlir::FlatSymbolRefAttr original_body_fn,
1372     mlir::func::FuncOp pred_fn, mlir::TypeRange arg_types,
1373     mlir::ConversionPatternRewriter &rewriter) const {
1374   int64_t parallel_iterations =
1375       enable_while_parallel_iterations_ ? op.parallel_iterations() : 1;
1376 
1377   std::string body_fn_name = original_body_fn.getValue().str() + "/tfrt_body_" +
1378                              absl::StrCat(parallel_iterations);
1379 
1380   if (auto body_fn = symbol_table_.lookup<mlir::func::FuncOp>(body_fn_name)) {
1381     return body_fn;
1382   }
1383 
1384   auto func_op = op->getParentOfType<mlir::func::FuncOp>();
1385 
1386   mlir::ConversionPatternRewriter::InsertionGuard insertion_guard(rewriter);
1387   rewriter.setInsertionPointAfter(func_op);
1388 
1389   auto while_result_types = arg_types;
1390 
1391   llvm::SmallVector<mlir::Type, 4> body_result_types(arg_types.begin(),
1392                                                      arg_types.end());
1393 
1394   // The last result of the while body function is the boolean condition.
1395   body_result_types.push_back(rewriter.getI1Type());
1396   auto func_type = rewriter.getFunctionType(arg_types, body_result_types);
1397   auto body_fn =
1398       rewriter.create<mlir::func::FuncOp>(op.getLoc(), body_fn_name, func_type);
1399   if (parallel_iterations > 1) {
1400     // Disable stream merging by setting cost threshold to 1. The key to
1401     // parallelize while iterations is to execute iteration index handling (e.g.
1402     // ++i; i < max_iterations) in parallel to loop bodies. Since iteration
1403     // index handling is usually cheap when this while loop is parallelizable,
1404     // we don't want to merge it with loop bodies into one stream of inline
1405     // execution. The quickest way to achieve this is to disable stream merging
1406     // for loop functions. The potential downside is stream merging won't be
1407     // applied to other parts of the loop body, so there might be excessive
1408     // threading overhead.
1409     //
1410     // TODO(chky): Consider a better way of parallelizing while ops, so that we
1411     // can have stream merging applied within the loop body. One option is to
1412     // perform compiler transformation to extract iteration index handling logic
1413     // out of the loop body and convert it to a parallel_map-like op.
1414     body_fn->setAttr("tfrt.cost_threshold", rewriter.getI64IntegerAttr(1));
1415   }
1416 
1417   auto *block = body_fn.addEntryBlock();
1418   rewriter.setInsertionPointToStart(block);
1419 
1420   // Insert a call to the original body function.
1421   auto call_original_body_fn = rewriter.create<tfrt::compiler::CallOp>(
1422       op.getLoc(), while_result_types, original_body_fn, block->getArguments());
1423 
1424   // Insert a call to the pred function, which contains a call to the original
1425   // cond function and the predicate kernel that converts the tensor to boolean
1426   // value.
1427   auto call_pred_fn = rewriter.create<tfrt::compiler::CallOp>(
1428       op.getLoc(), pred_fn.getFunctionType().getResults(), pred_fn.getSymName(),
1429       call_original_body_fn.getResults());
1430 
1431   auto pred_chain = call_pred_fn.getResult(0);
1432 
1433   llvm::SmallVector<mlir::Value, 4> body_results;
1434   // The first result should be the chain returned from the pred function.
1435   body_results.push_back(pred_chain);
1436   // Then come the results from the original while body, excluding the first
1437   // chain (which is replaced by the `pred_chain`).
1438   for (auto res : llvm::drop_begin(call_original_body_fn.getResults())) {
1439     body_results.push_back(res);
1440   }
1441   // The last result should be the boolean value converted from the condition.
1442   auto bool_cond = call_pred_fn.getResult(1);
1443   body_results.push_back(bool_cond);
1444 
1445   rewriter.create<tfrt::compiler::ReturnOp>(op.getLoc(), body_results);
1446 
1447   symbol_table_.insert(body_fn);
1448 
1449   return body_fn;
1450 }
1451 
1452 // TODO(ezhulenev): tf_device.cluster operations after auto-fusion should
1453 // have the correct device assigned based on the fused operations. We should
1454 // use this device to convert operands and results from/to corert handles.
1455 // For now it is safe to assume that it is "CPU" because we do not support
1456 // any other devices and do not support distributed models.
1457 constexpr char kJitRtDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0";
1458 
1459 // Convert jitrt.call operations to the tf_jitrt.fallback.execute operation.
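// Schematically (an illustrative sketch only; the exact assembly syntax is
// defined by the respective dialects),
//
//   %0 = jitrt.call @kernel::@compute(%arg)
//
// is rewritten to
//
//   %0 = tf_jitrt.fallback.execute @kernel::@compute(%arg) {device = <CPU>}
//
// after converting all operands to fallback tensors; the results are always
// fallback tensors as well.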
1460 class JitRtCallToJitRtCompileAndExecuteConversion
1461     : public OpConversionPattern<tfrt::jitrt::CallOp> {
1462  public:
1463   explicit JitRtCallToJitRtCompileAndExecuteConversion(MLIRContext *context)
1464       : OpConversionPattern<tfrt::jitrt::CallOp>(context) {}
1465 
1466   LogicalResult matchAndRewrite(
1467       tfrt::jitrt::CallOp call, OpAdaptor adaptor,
1468       ConversionPatternRewriter &rewriter) const override {
1469     // Convert operands to fallback tensors.
1470     llvm::SmallVector<Value, 4> fallback_operands;
1471     if (failed(tfrt_compiler::ConvertFallbackOperands(
1472             call, kJitRtDevice, adaptor.getOperands(), &fallback_operands,
1473             rewriter)))
1474       return rewriter.notifyMatchFailure(call, "failed to convert operand");
1475 
1476     // tf_jitrt.fallback.execute always produces fallback tensors.
1477     llvm::SmallVector<Type, 4> result_types(
1478         call->getNumResults(),
1479         rewriter.getType<tfrt::fallback::TFTensorType>());
1480 
1481     // Replace jitrt.call operation with a tf_jitrt.fallback.execute operation.
1482     rewriter.replaceOpWithNewOp<tf_jitrt::FallbackExecuteOp>(
1483         call, result_types, call.callee(), fallback_operands, kJitRtDevice);
1484 
1485     return success();
1486   }
1487 };
1488 
1489 // Helper function for specifying legal dialects for conversion to CoreRT.
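// A func.func op is considered legal only if it threads a !tfrt.chain through
// as both its first argument and its first result, and the rest of the
// signature and body are legal under the given type converter. For example
// (a sketch), a legal signature looks roughly like:
//
//   func @fn(%ch: !tfrt.chain, %t: !tfrt_fallback.tf_tensor)
//       -> (!tfrt.chain, !tfrt_fallback.tf_tensor)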
1490 void SetUpTFToTFRTConversionLegality(mlir::ConversionTarget *target,
1491                                      mlir::TypeConverter *func_type_converter,
1492                                      mlir::Type chain_type) {
1493   target->addLegalDialect<tfrt::corert::CoreRTDialect>();
1494   target->addLegalDialect<tfrt::fallback_async::FallbackAsyncDialect>();
1495   target->addLegalDialect<tfrt::compiler::TFRTDialect>();
1496   target->addLegalDialect<tfrt::dist::DistributedDialect>();
1497   target->addLegalDialect<tfrt::test::TestDialect>();
1498   target->addLegalDialect<tf_jitrt::JitRuntimeDialect>();
1499   target->addIllegalDialect<TF::TensorFlowDialect>();
1500   target->addIllegalDialect<tf_device::TensorFlowDeviceDialect>();
1501   target->addIllegalDialect<tfrt::jitrt::JitRuntimeDialect>();
1502   target->addDynamicallyLegalOp<mlir::func::FuncOp>([func_type_converter,
1503                                                      chain_type](
1504                                                         func::FuncOp op) {
1505     auto func_type = op.getFunctionType();
1506     if (func_type.getNumInputs() == 0 || func_type.getInput(0) != chain_type)
1507       return false;
1508     if (func_type.getNumResults() == 0 || func_type.getResult(0) != chain_type)
1509       return false;
1510 
1511     return func_type_converter->isSignatureLegal(op.getFunctionType()) &&
1512            func_type_converter->isLegal(&op.getBody());
1513   });
1514 }
1515 
1516 // Helper function for inserting TFRT JitRt dialect conversions.
1517 void PopulateJitRtConversionPatterns(MLIRContext *context,
1518                                      RewritePatternSet *patterns,
1519                                      CoreRTConverter *corert_converter) {
1520   // Lower jitrt.call to the pair of compile and execute operations.
1521   patterns->add<JitRtCallToJitRtCompileAndExecuteConversion>(context);
1522 }
1523 
1524 // Helper function for inserting TF dialect to TFRT dialect op conversion
1525 // patterns.
1526 void PopulateTFToTFRTConversionPatterns(
1527     mlir::MLIRContext *context, mlir::RewritePatternSet *patterns,
1528     CoreRTConverter *corert_converter,
1529     tfrt_compiler::FallbackConverter *fallback_converter,
1530     mlir::SymbolTable *symbol_table,
1531     const tfrt_compiler::CostAnalysis *cost_analysis,
1532     const tfrt_compiler::TensorArraySideEffectAnalysis
1533         *tensor_array_side_effect_analysis,
1534     bool enable_native_ops, bool func_use_fallback_tensor,
1535     bool enable_while_parallel_iterations, bool tpu_lower_to_fallback,
1536     bool target_tpurt) {
1537   // By default, we lower all TF ops to fallback ops.
1538   patterns->add<FallbackExecuteOpConversion>(
1539       context, corert_converter, fallback_converter, cost_analysis,
1540       tpu_lower_to_fallback, target_tpurt);
1541   patterns->add<FallbackConstOpConversion, FallbackSetResourceOp,
1542                 FallbackGetResourceOp>(context, corert_converter);
1543 
1544   // For control flow ops, we handle them according to the option.
1545   mlir::TypeConverter *func_type_converter;
1546   if (func_use_fallback_tensor) {
1547     func_type_converter = fallback_converter;
1548   } else {
1549     func_type_converter = corert_converter;
1550   }
1551   patterns->add<TFRTFuncOpSignatureConversion>(context, func_type_converter);
1552   patterns->add<TFRTReturnOpConversion>(context, corert_converter,
1553                                         func_use_fallback_tensor);
1554   patterns->add<TFRTWhileOpConversion>(
1555       context, func_type_converter, corert_converter, symbol_table,
1556       tensor_array_side_effect_analysis, func_use_fallback_tensor,
1557       enable_while_parallel_iterations);
1558   patterns->add<TFRTCallOpConversion<mlir::TF::StatefulPartitionedCallOp>,
1559                 TFRTCallOpConversion<mlir::TF::PartitionedCallOp>,
1560                 TFRTCallOpConversion<mlir::TF::LegacyCallOp>,
1561                 TFRTCaseOpConversion, TFRTCondOpConversion>(
1562       context, func_type_converter, corert_converter, func_use_fallback_tensor);
1563 
1564   // For tf.BatchFunction, we need a special fallback op to batch a BEF
1565   // function.
1566   patterns->add<FallbackBatchFunctionOpConversion>(context, corert_converter);
1567 
1568   // The patterns below are preferred over fallback lowering as we want to use
1569   // the CoreRT interface for native kernels. This is only temporary and it will
1570   // be refactored to use the SSOT interface.
1571 
1572   // Here we use specialized patterns for tf.Const on CPU as it is incorrect to
1573   // use the ExecuteOp pattern to convert string tensor attributes.
1574   patterns->add<CoreRTConstStringTensorOpConversion,
1575                 CoreRTConstDenseTensorOpConversion>(context, corert_converter);
1576 
1577   if (enable_native_ops) {
1578     // The TF operations below will be converted to use corert.executeop, which
1579     // will invoke the TFRT native op if implemented.
1580     // TODO(b/187942369): Pattern registration for TF operations is not
1581     // sustainable currently. We need to figure out a plan.
1582     patterns->add<CoreRTExecuteOpConversion<TF::AddV2Op>,
1583                   // TODO(chky): Move the ReadVariableOp + Identity pattern
1584                   // to optimizer.
1585                   // CoreRTExecuteOpConversion<TF::IdentityOp>,
1586                   CoreRTExecuteOpConversion<TF::MulOp>,
1587                   CoreRTExecuteOpConversion<TF::BiasAddOp>,
1588                   CoreRTExecuteOpConversion<TF::Conv2DOp>,
1589                   CoreRTExecuteOpConversion<TF::ConcatV2Op>,
1590                   CoreRTExecuteOpConversion<TF::CastOp>,
1591                   CoreRTExecuteOpConversion<TF::ExpandDimsOp>,
1592                   CoreRTExecuteOpConversion<TF::TransposeOp>,
1593                   CoreRTExecuteOpConversion<TF::FusedBatchNormV3Op>,
1594                   CoreRTExecuteOpConversion<TF::_FusedBatchNormExOp>,
1595                   CoreRTExecuteOpConversion<TF::LogOp>,
1596                   CoreRTExecuteOpConversion<TF::Log1pOp>,
1597                   CoreRTExecuteOpConversion<TF::LogSoftmaxOp>,
1598                   CoreRTExecuteOpConversion<TF::MatMulOp>,
1599                   CoreRTExecuteOpConversion<TF::_FusedMatMulOp>,
1600                   CoreRTExecuteOpConversion<TF::MaxPoolOp>,
1601                   CoreRTExecuteOpConversion<TF::MeanOp>,
1602                   CoreRTExecuteOpConversion<TF::MulOp>,
1603                   CoreRTExecuteOpConversion<TF::PadOp>,
1604                   CoreRTExecuteOpConversion<TF::RealDivOp>,
1605                   CoreRTExecuteOpConversion<TF::ReluOp>,
1606                   CoreRTExecuteOpConversion<TF::ReshapeOp>,
1607                   CoreRTExecuteOpConversion<TF::RsqrtOp>,
1608                   CoreRTExecuteOpConversion<TF::SoftmaxOp>,
1609                   CoreRTExecuteOpConversion<TF::ShapeOp>,
1610                   CoreRTExecuteOpConversion<TF::SigmoidOp>,
1611                   CoreRTExecuteOpConversion<TF::SubOp>,
1612                   CoreRTExecuteOpConversion<TF::TileOp>,
1613                   CoreRTExecuteOpConversion<TF::TanhOp>,
1614                   CoreRTExecuteOpConversion<TF::ZerosLikeOp>>(context,
1615                                                               corert_converter);
1616   }
1617 }
1618 
1619 // Lower TF dialect MLIR to TFRT dialect.
1620 class TfToTfrtConversionPass
1621     : public mlir::PassWrapper<TfToTfrtConversionPass,
1622                                mlir::OperationPass<mlir::ModuleOp>> {
1623   void getDependentDialects(mlir::DialectRegistry &registry) const override {
1624     getDependentConversionDialects(registry);
1625 
1626     if (target_tpurt_) RegisterTPUDialects(&registry);
1627   }
1628 
1629   llvm::StringRef getArgument() const final { return "tf-to-tfrt"; }
1630   llvm::StringRef getDescription() const final {
1631     return "Convert Tensorflow dialect (generated from tf.function) to TFRT "
1632            "dialect.";
1633   }
1634 
1635  public:
1636   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TfToTfrtConversionPass)
1637 
1638   TfToTfrtConversionPass() = default;
1639   explicit TfToTfrtConversionPass(const TfrtPipelineOptions &options) {
1640     target_tpurt_ = options.target_tpurt;
1641     enable_native_ops_ = options.enable_native_ops;
1642     tpu_use_core_selector_ = options.tpu_use_core_selector;
1643     tpu_use_bundled_transfer_ = options.tpu_use_bundled_transfer;
1644     tpu_lower_to_fallback_ = options.tpu_lower_to_fallback;
1645     tpu_transfer_result_to_host_ = options.tpu_transfer_result_to_host;
1646     use_tpu_host_allocator_for_inputs_ =
1647         options.use_tpu_host_allocator_for_inputs;
1648     cost_threshold_ = options.cost_threshold;
1649     upper_cost_threshold_ = options.upper_cost_threshold;
1650     merge_inter_dependent_streams_ = options.merge_inter_dependent_streams;
1651     func_use_fallback_tensor_ = options.func_use_fallback_tensor;
1652     enable_while_parallel_iterations_ =
1653         options.enable_while_parallel_iterations;
1654   }
1655   TfToTfrtConversionPass(const TfToTfrtConversionPass &) {}
1656 
1657   mlir::LogicalResult runOnFunction(
1658       mlir::func::FuncOp func,
1659       const mlir::TF::SideEffectAnalysis::Info &side_effect_analysis,
1660       const tfrt_compiler::TensorArraySideEffectAnalysis
1661           &tensor_array_side_effect_analysis,
1662       tfrt_compiler::FallbackConverter &fallback_converter,
1663       mlir::SymbolTable &symbol_table) {
1664     auto &context = getContext();
1665     mlir::ConversionTarget target(context);
1666     mlir::RewritePatternSet patterns(&getContext());
1667     CoreRTConverter corert_converter(&context, &side_effect_analysis);
1668     tfrt_compiler::CostAnalysis cost_analysis(func);
1669 
1670     if (target_tpurt_)
1671       AddTPUTargetDialectAndPatterns(
1672           &target, &patterns, &context, &corert_converter, &fallback_converter,
1673           TfrtTpuExecuteOpConversionOptions{
1674               tpu_use_core_selector_, tpu_use_bundled_transfer_,
1675               tpu_transfer_result_to_host_, use_tpu_host_allocator_for_inputs_},
1676           tpu_lower_to_fallback_);
1677 
1678     mlir::TypeConverter *func_type_converter;
1679     if (func_use_fallback_tensor_) {
1680       func_type_converter = &fallback_converter;
1681     } else {
1682       func_type_converter = &corert_converter;
1683     }
1684     SetUpTFToTFRTConversionLegality(&target, func_type_converter,
1685                                     corert_converter.chain_type());
1686     PopulateJitRtConversionPatterns(&context, &patterns, &corert_converter);
1687 
1688     PopulateTFToTFRTConversionPatterns(
1689         &context, &patterns, &corert_converter, &fallback_converter,
1690         &symbol_table, &cost_analysis, &tensor_array_side_effect_analysis,
1691         enable_native_ops_, func_use_fallback_tensor_,
1692         enable_while_parallel_iterations_, tpu_lower_to_fallback_,
1693         target_tpurt_);
1694 
1695     return mlir::applyPartialConversion(func, target, std::move(patterns));
1696   }
1697 
1698   void runOnOperation() override {
1699     auto module = getOperation();
1700     const auto &side_effect_analysis =
1701         getAnalysis<mlir::TF::SideEffectAnalysis>();
1702 
1703     tfrt_compiler::TensorArraySideEffectAnalysis
1704         tensor_array_side_effect_analysis(module);
1705     tfrt_compiler::FallbackConverter fallback_converter(&getContext());
1706 
1707     mlir::SymbolTable symbol_table(module);
1708 
1709     auto func_op_range = module.getOps<mlir::func::FuncOp>();
1710     llvm::SmallVector<mlir::func::FuncOp, 4> func_ops(func_op_range.begin(),
1711                                                       func_op_range.end());
1712     for (auto func : func_ops) {
1713       if (!func.isExternal()) {
1714         if (mlir::failed(runOnFunction(
1715                 func, side_effect_analysis.GetAnalysisForFunc(func),
1716                 tensor_array_side_effect_analysis, fallback_converter,
1717                 symbol_table))) {
1718           signalPassFailure();
1719           return;
1720         }
1721 
1722         ChainDanglingValuesinFunction(func);
1723       }
1724     }
1725 
1726     CreateFallbackInitializationFunction(module, fallback_converter);
1727 
1728     // Set cost threshold as a module attribute. It will be used later in the
1729     // runtime to decide whether certain execution is cheap enough to be inline
1730     // executed.
1731     mlir::Builder builder(module);
1732     module->setAttr("tfrt.cost_threshold",
1733                     builder.getI64IntegerAttr(cost_threshold_));
1734     module->setAttr("tfrt.upper_cost_threshold",
1735                     builder.getI64IntegerAttr(upper_cost_threshold_));
1736     module->setAttr("tfrt.merge_inter_dependent_streams",
1737                     builder.getBoolAttr(merge_inter_dependent_streams_));
1738   }
1739 
1740  private:
1741   // Chain all dangling values (i.e. values with no users) together and merge
1742   // them with the first returned chain. This merged chain can be used to signal
1743   // the completion of all execution, including side effects.
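  //
  // For example (a sketch; op names shown schematically): if %unused is
  // produced but never consumed, the terminator
  //
  //   tfrt.return %ch, %res
  //
  // is rewritten to
  //
  //   %merged = tfrt.merge.chains %ch, %unused
  //   tfrt.return %merged, %res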
1744   void ChainDanglingValuesinFunction(mlir::func::FuncOp func_op) {
1745     auto &block = func_op.front();
1746 
1747     llvm::SmallVector<mlir::Value, 2> dangling_values;
1748 
1749     // Find dangling function arguments.
1750     for (auto arg : block.getArguments()) {
1751       if (arg.use_empty()) {
1752         dangling_values.push_back(arg);
1753       }
1754     }
1755 
1756     // Find dangling values produced by ops in the function.
1757     for (auto &op : block) {
1758       for (mlir::Value result : op.getResults()) {
1759         if (result.use_empty()) {
1760           dangling_values.push_back(result);
1761         }
1762       }
1763     }
1764 
1765     if (dangling_values.empty()) return;
1766 
1767     // Merge these dangling values with the first returned chain.
1768     auto return_op =
1769         llvm::cast<tfrt::compiler::ReturnOp>(block.getTerminator());
1770     auto chain = return_op->getOperand(0);
1771     assert(chain.getType().isa<tfrt::compiler::ChainType>());
1772     dangling_values.push_back(chain);
1773 
1774     mlir::OpBuilder builder(return_op);
1775 
1776     auto new_chain = builder.create<tfrt::compiler::MergeChainsOp>(
1777         return_op->getLoc(), builder.getType<tfrt::compiler::ChainType>(),
1778         dangling_values);
1779 
1780     return_op->setOperand(0, new_chain);
1781   }
1782 
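  // Create a "_tfrt_fallback_init" function that creates all fallback kernels
  // used in the module and pre-compiles all JIT-compiled kernels, threading a
  // !tfrt.chain through the created ops. Roughly (a sketch only; the actual op
  // spellings are whatever the ops created below print as):
  //
  //   func @_tfrt_fallback_init(%in_ch: !tfrt.chain) -> !tfrt.chain {
  //     %ch0 = <fallback_async create op>(%in_ch) ...  // one per fallback op
  //     %c0  = <tf_jitrt fallback compile op> ...      // one per JIT kernel
  //     %out = <merge chains op>(%ch0, %c0, ...)
  //     tfrt.return %out
  //   }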
1783   void CreateFallbackInitializationFunction(
1784       mlir::ModuleOp module,
1785       tfrt_compiler::FallbackConverter &fallback_converter) {
1786     mlir::OpBuilder builder(&module.getBodyRegion());
1787 
1788     auto chain_type = builder.getType<tfrt::compiler::ChainType>();
1789 
1790     auto func_op = builder.create<mlir::func::FuncOp>(
1791         module.getLoc(), "_tfrt_fallback_init",
1792         mlir::FunctionType::get(module.getContext(), /*inputs=*/chain_type,
1793                                 /*results=*/chain_type));
1794 
1795     auto *block = func_op.addEntryBlock();
1796     builder.setInsertionPointToStart(block);
1797 
1798     mlir::Value chain_value = block->getArgument(0);
1799 
1800     // Create operations for all fallback kernels in the module.
1801     for (auto *op : fallback_converter.GetFallbackOps()) {
1802       auto create_op = builder.create<tfrt::fallback_async::CreateOp>(
1803           func_op.getLoc(), chain_type, chain_value);
1804 
1805       create_op->setAttrs(op->getAttrs());
1806       create_op->setAttr("num_args", builder.getI64IntegerAttr(GetNumArgs(op)));
1807 
1808       chain_value = create_op;
1809     }
1810 
1811     // Pre-compile all JIT compiled kernels found in the module.
1812     llvm::SmallVector<Value> compiled;
1813 
1814     // A set of SymbolRef attributes referencing compiled kernels.
1815     llvm::DenseSet<mlir::Attribute> kernels;
1816 
1817     // Compile all kernels in parallel.
1818     module.walk([&](tf_jitrt::FallbackExecuteOp execute) {
1819       // Do not compile the same kernel multiple times.
1820       if (kernels.contains(execute.kernel())) return;
1821 
1822       auto compile = builder.create<tf_jitrt::FallbackCompileOp>(
1823           execute.getLoc(), chain_type, execute.kernel(), execute.device());
1824       compiled.push_back(compile.getResult());
1825       kernels.insert(compile.kernel());
1826     });
1827 
1828     // Wait for the compilation completion before returning from init function.
1829     if (!compiled.empty()) {
1830       // Do not forget to wait for the fallback kernels initialization.
1831       compiled.insert(compiled.begin(), chain_value);
1832       chain_value = builder.create<tfrt::compiler::MergeChainsOp>(
1833           func_op.getLoc(), chain_type, compiled);
1834     }
1835 
1836     builder.create<tfrt::compiler::ReturnOp>(func_op.getLoc(), chain_value);
1837   }
1838 
1839   int64_t GetNumArgs(mlir::Operation *fallback_op) {
1840     if (auto execute_op =
1841             llvm::dyn_cast<tfrt::fallback_async::ExecuteOp>(fallback_op)) {
1842       return execute_op.operands().size();
1843     } else if (auto execute_op_seq =
1844                    llvm::dyn_cast<tfrt::fallback_async::ExecuteOpSeq>(
1845                        fallback_op)) {
1846       return execute_op_seq.operands().size();
1847     } else if (auto execute_op_allocator =
1848                    llvm::dyn_cast<tfrt::fallback_async::ExecuteOpWithAllocator>(
1849                        fallback_op)) {
1850       return execute_op_allocator.operands().size();
1851     } else if (auto execute_op_seq_allocator = llvm::dyn_cast<
1852                    tfrt::fallback_async::ExecuteOpSeqWithAllocator>(
1853                    fallback_op)) {
1854       return execute_op_seq_allocator.operands().size();
1855     }
1856     llvm_unreachable("invalid fallback op type");
1857   }
1858 
1859   Option<bool> target_tpurt_{*this, "target-tpurt",
1860                              llvm::cl::desc("Target TPURT dialect if true."),
1861                              llvm::cl::init(false)};
1862 
1863   Option<bool> enable_native_ops_{
1864       *this, "enable-native-ops",
1865       llvm::cl::desc(
1866           "If true, native ops will be used on an opt-in basis "
1867           "instead of fallback ops. If false, no native ops are used."),
1868       llvm::cl::init(true)};
1869 
1870   Option<bool> tpu_use_core_selector_{
1871       *this, "tpu-use-core-selector",
1872       llvm::cl::desc("If true, use ServingCoreSelector to pick TPU core. "
1873                      "Otherwise, use the assigned core."),
1874       llvm::cl::init(true)};
1875 
1876   Option<bool> tpu_use_bundled_transfer_{
1877       *this, "tpu-use-bundled-transfer",
1878       llvm::cl::desc("If true, use BundledTransferToTpuOp to transfer "
1879                      "variables and input tensors to TPU."),
1880       llvm::cl::init(true)};
1881 
1882   Option<bool> tpu_lower_to_fallback_{
1883       *this, "tpu-lower-to-fallback",
1884       llvm::cl::desc("If true, lower a TF op that's placed on TPU device "
1885                      "to be executed by tfrt_fallback.execute. Note that this "
1886                      "applies to ops other than TPU compile and execute ops."),
1887       llvm::cl::init(true)};
1888 
1889   // TODO(b/194081364): remove this option once we unify servo TPU serving
1890   // result transfer behavior.
1891   Option<bool> tpu_transfer_result_to_host_{
1892       *this, "tpu-transfer-result-to-host",
1893       llvm::cl::desc("If true, transfer the result of tpurt.execute from TPU "
1894                      "to host."),
1895       llvm::cl::init(true)};
1896 
1897   Option<bool> use_tpu_host_allocator_for_inputs_{
1898       *this, "use-tpu-host-allocator-for-inputs",
1899       llvm::cl::desc("If true, fallback executeops that produce inputs to tpu "
1900                      "program will use tpu host allocator."),
1901       llvm::cl::init(false)};
1902 
1903   Option<uint64_t> cost_threshold_{
1904       *this, "tfrt-cost-threshold",
1905       llvm::cl::desc(
1906           "The cost threshold to decide whether a sequence of operations is "
1907           "cheap, and then whether it can be executed inline."),
1908       llvm::cl::init(1)};
1909 
1910   Option<int64_t> upper_cost_threshold_{
1911       *this, "tfrt-upper-cost-threshold",
1912       llvm::cl::desc(
1913           "The threshold to limit the merging of dependent sequence."),
1914       llvm::cl::init(-1)};
1915 
1916   Option<bool> merge_inter_dependent_streams_{
1917       *this, "tfrt-merge-inter-dependent-streams",
1918       llvm::cl::desc("If true, streams with inter data dependencies will be "
1919                      "preferred to be merged for inline execution."),
1920       llvm::cl::init(false)};
1921 
1922   Option<bool> func_use_fallback_tensor_{
1923       *this, "func-use-fallback-tensor",
1924       llvm::cl::desc(
1925           "If true, use TF tensor as input/output types in func (and other "
1926           "control flow) ops."),
1927       llvm::cl::init(false)};
1928 
1929   Option<bool> enable_while_parallel_iterations_{
1930       *this, "enable-while-parallel-iterations",
1931       llvm::cl::desc("If true, tf.While op will be parallelized. This is "
1932                      "currently experimental."),
1933       llvm::cl::init(false)};
1934 };
1935 
1936 // Assigns devices so that later passes can utilize device information.
1937 // Device assignment might not have been done by the upstream pipeline, or get
1938 // removed by previous passes. However, we assume most of the device assignment
1939 // has been done by the upstream pipeline, so we simply assign the default
1940 // device to unassigned ops. Specifically, we do assignment for ConstOp first to
1941 // place it on the same device as its user operation, instead of placing it on
1942 // the default device blindly.
1943 // TODO(b/221297389): Figure out a more robust way to handle dropped device
1944 // assignment.
1945 void AddTfDeviceAssignmentPasses(mlir::OpPassManager &pm,
1946                                  const TfrtPipelineOptions &options) {
1947   pm.addPass(mlir::TF::CreateConstantOpDeviceAssignmentPass());
1948   pm.addNestedPass<mlir::func::FuncOp>(
1949       mlir::TF::CreateTFDeviceAssignmentByFuncAttrPass());
1950   pm.addNestedPass<mlir::func::FuncOp>(
1951       mlir::TF::CreateSimpleTFDeviceAssignmentPass(options.default_device));
1952 }
1953 
1954 }  // namespace
1955 
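// Example usage (a sketch; `options` is a TfrtPipelineOptions configured by
// the caller, `module` is the mlir::ModuleOp to convert):
//
//   mlir::PassManager pm(&context);
//   pm.addPass(CreateTfToTfrtConversionPass(options));
//   if (mlir::failed(pm.run(module))) { /* handle the failure */ }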
1956 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
1957 CreateTfToTfrtConversionPass(const TfrtPipelineOptions &options) {
1958   return std::make_unique<TfToTfrtConversionPass>(options);
1959 }
1960 
1961 // -------------------------------------------------------------------------- //
1962 // Outline tf_device.cluster operation regions into functions in the nested
1963 // modules and replaces all cluster operations with jitrt.call operations.
1964 // -------------------------------------------------------------------------- //
1965 
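// Schematically (an illustrative sketch only; exact assembly syntax omitted),
// a cluster formed for TFRT JIT compilation,
//
//   %r = "tf_device.cluster"() ({ ... tf_device.return %v })
//            {policy = "tfrt.auto-fusion"}
//
// is outlined into a nested module
//
//   module @kernel attributes {tfrt.compiled} {
//     func @compute(%arg) -> ... { ... func.return %v }
//   }
//
// and the original cluster op is replaced with
//
//   %r = jitrt.call @kernel::@compute(%operands)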
1966 class OutlineJitRtClustersPass
1967     : public PassWrapper<OutlineJitRtClustersPass, OperationPass<ModuleOp>> {
1968  public:
1969   llvm::StringRef getArgument() const final {
1970     return "tf-outline-jitrt-cluster";
1971   }
1972   llvm::StringRef getDescription() const final {
1973     return "Outlines `tf_device.cluster` operations into functions and "
1974            "replaces them with `jitrt.call` operations.";
1975   }
1976 
1977   void runOnOperation() override;
1978 
1979   void getDependentDialects(mlir::DialectRegistry &registry) const override {
1980     registry.insert<tfrt::jitrt::JitRuntimeDialect>();
1981   }
1982 
1983  private:
1984   struct CompiledModule {
1985     ModuleOp module;
1986     func::FuncOp entrypoint;
1987     llvm::SetVector<Value> operands;
1988   };
1989 
1990   // Creates a nested module with a single function that will be compiled into
1991   // the kernel at runtime.
1992   CompiledModule CreateCompiledModule(tf_device::ClusterOp cluster,
1993                                       int64_t max_arg_size,
1994                                       SymbolTable *symbol_table);
1995 
1996   // Update compiled module entrypoint signature with inferred operands
1997   // constraints.
1998   LogicalResult SetEntrypointConstraints(CompiledModule &compiled);
1999 
2000   // Outlines cluster operation regions into compiled modules, and replaces
2001   // cluster operation with a jitrt.call operation.
2002   LogicalResult OutlineClusterOp(tf_device::ClusterOp cluster,
2003                                  int64_t max_arg_size,
2004                                  SymbolTable *symbol_table);
2005 
2006   // Mapping from the outlined module string representation to the module itself
2007   // and an entrypoint function. Used to deduplicate identical modules during
2008   // the `tf_device.cluster` outlining.
2009   llvm::StringMap<std::pair<ModuleOp, func::FuncOp>> outlined_;
2010 };
2011 
2012 OutlineJitRtClustersPass::CompiledModule
2013 OutlineJitRtClustersPass::CreateCompiledModule(tf_device::ClusterOp cluster,
2014                                                int64_t max_arg_size,
2015                                                SymbolTable *symbol_table) {
2016   MLIRContext *ctx = cluster->getContext();
2017   Location loc = cluster.getLoc();
2018 
2019   // Create a module that will hold compiled function and async wrappers.
2020   // TODO(ezhulenev): Give better names to module and function.
2021   auto compiled_module = ModuleOp::create(loc, {"kernel"});
2022   compiled_module->setAttr("tfrt.compiled", UnitAttr::get(ctx));
2023   compiled_module->setAttr(
2024       "tfrt.max-arg-size",
2025       IntegerAttr::get(IntegerType::get(ctx, 64), max_arg_size));
2026 
2027   SymbolTable compiled_module_symbol_table(compiled_module);
2028 
2029   // Find out the cluster arguments and their types.
2030   llvm::SetVector<Value> live_ins;
2031   getUsedValuesDefinedAbove(cluster.body(), cluster.body(), live_ins);
2032 
2033   llvm::SmallVector<Type, 4> operand_types;
2034   operand_types.reserve(live_ins.size());
2035   for (Value v : live_ins) operand_types.emplace_back(v.getType());
2036 
2037   // Create a function in the compiled module.
2038   auto compiled_func_type =
2039       FunctionType::get(ctx, operand_types, cluster->getResultTypes());
2040   auto compiled_func = func::FuncOp::create(loc, "compute", compiled_func_type);
2041   compiled_module_symbol_table.insert(compiled_func);
2042 
2043   // Replace uses of live-in values within cluster region with block arguments.
2044   Block *compiled_func_block = compiled_func.addEntryBlock();
2045   for (auto p : llvm::zip(live_ins, compiled_func_block->getArguments()))
2046     replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), cluster.body());
2047 
2048   // Move all operations in cluster into compiled_func's entry block.
2049   auto &cluster_body = cluster.GetBody().getOperations();
2050   compiled_func_block->getOperations().splice(
2051       compiled_func_block->end(), cluster_body, cluster_body.begin(),
2052       cluster_body.end());
2053 
2054   // Replace `tf_device.return` terminator with `func.return` in the function
2055   // body.
2056   auto device_return =
2057       cast<tf_device::ReturnOp>(compiled_func_block->getTerminator());
2058   OpBuilder builder(device_return.getOperation());
2059   builder.create<func::ReturnOp>(device_return.getLoc(),
2060                                  device_return.getOperands());
2061   device_return.erase();
2062 
2063   // TODO(ezhulenev): MLIR doesn't define operation equivalence upstream yet,
2064   // replace module printing with a more principled solution when available.
2065   // Operations in the cluster can be in a different order and yet define
2066   // identical TensorFlow programs; with the current approach we will not be
2067   // able to detect duplicates like this.
2068 
2069   // Remove location attribute attached to Tensorflow operations to be able to
2070   // deduplicate compiled clusters with the same set of operations.
2071   //
2072   // TODO(ezhulenev): Figure out how to propagate locations for error reporting,
2073   // right now JitRt will ignore them anyway.
2074   compiled_module.walk([](Operation *op) { op->removeAttr("_class"); });
2075 
2076   // Serialize prepared module to string.
2077   std::string serialized;
2078   llvm::raw_string_ostream os(serialized);
2079   compiled_module.print(os);
2080 
2081   // Try to find if identical module was already outlined.
2082   auto it = outlined_.find(serialized);
2083 
2084   // Return identical module that was already outlined earlier.
2085   if (it != outlined_.end()) {
2086     compiled_module.erase();  // erase identical module
2087     return {it->second.first, it->second.second, live_ins};
2088   }
2089 
2090   // Insert compiled module into the symbol table and assign it a unique name.
2091   symbol_table->insert(compiled_module);
2092 
2093   // Cache unique module.
2094   outlined_.insert({std::move(serialized), {compiled_module, compiled_func}});
2095 
2096   return {compiled_module, compiled_func, live_ins};
2097 }
2098 
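// Annotates each entrypoint argument with an "rt.constraint" string attribute
// derived from the inferred clustering constraints, e.g. (a sketch; the
// concrete constraint names come from the tf_jitrt clustering policies):
//
//   func @compute(%arg0: tensor<?xf32> {rt.constraint = "rank"}, ...)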
2099 LogicalResult OutlineJitRtClustersPass::SetEntrypointConstraints(
2100     CompiledModule &compiled) {
2101   func::FuncOp func = compiled.entrypoint;
2102 
2103   // Functions outlined from jitrt device clusters must have a single block.
2104   assert(func.getBody().getBlocks().size() == 1 && "expected single block");
2105 
2106   mlir::TFDevice::ClusteringPolicySet policies;
2107   populateTfJitRtConstraintsPolicies(policies);
2108 
2109   // Infer constraints on the values defined in the entrypoint function
2110   // (including function entry block arguments).
2111   mlir::TFDevice::ValuesConstraintSet constraints;
2112   if (failed(mlir::TFDevice::PropagateValuesConstraints(
2113           func.getBody(), policies, constraints, /*resolve=*/true)))
2114     return failure();
2115 
2116   // Annotate arguments with inferred constraints.
2117   for (unsigned i = 0; i < func.getNumArguments(); ++i) {
2118     if (auto constraint = constraints.GetConstraint(func.getArgument(i))) {
2119       auto constraint_name = mlir::StringAttr::get(
2120           &getContext(), llvm::formatv("{0}", *constraint).str());
2121       func.setArgAttr(i, "rt.constraint", constraint_name);
2122     }
2123   }
2124 
2125   return success();
2126 }
2127 
2128 LogicalResult OutlineJitRtClustersPass::OutlineClusterOp(
2129     tf_device::ClusterOp cluster, int64_t max_arg_size,
2130     SymbolTable *symbol_table) {
2131   Location loc = cluster->getLoc();
2132   OpBuilder builder(cluster);
2133 
2134   CompiledModule compiled_module =
2135       CreateCompiledModule(cluster, max_arg_size, symbol_table);
2136   func::FuncOp compiled_func = compiled_module.entrypoint;
2137 
2138   // Add constraints to the entrypoint arguments.
2139   if (failed(SetEntrypointConstraints(compiled_module))) return failure();
2140 
2141   // Replace device cluster with a jitrt.call operation.
2142   auto module_name = *compiled_module.module.getSymName();
2143   auto func_name = compiled_func.getSymName();
2144   auto func_flat_ref =
2145       mlir::SymbolRefAttr::get(builder.getContext(), func_name);
2146   auto func_ref = mlir::SymbolRefAttr::get(builder.getContext(), module_name,
2147                                            {func_flat_ref});
2148 
2149   auto cluster_func_op = builder.create<tfrt::jitrt::CallOp>(
2150       loc, cluster.getResultTypes(), func_ref,
2151       compiled_module.operands.getArrayRef());
2152 
2153   cluster.replaceAllUsesWith(cluster_func_op);
2154   cluster.erase();
2155 
2156   return success();
2157 }
2158 
2159 void OutlineJitRtClustersPass::runOnOperation() {
2160   ModuleOp module = getOperation();
2161   SymbolTable symbol_table(module);
2162 
2163   // Keep track of the maximum argument size for each function with tf_device
2164   // cluster operations in the function body. We need to pass it to the compiled
2165   // module to correctly compute its cost later.
2166   llvm::DenseMap<mlir::func::FuncOp, int64_t> max_arg_size_map;
2167 
2168   auto get_max_arg_size = [&](mlir::func::FuncOp func) -> int64_t {
2169     auto it = max_arg_size_map.find(func);
2170     if (it != max_arg_size_map.end()) return it->second;
2171     return max_arg_size_map[func] = tf_jitrt::GetMaxArgSize(func);
2172   };
2173 
2174   OpBuilder builder(module.getContext());
2175   auto result = module.walk([&](tf_device::ClusterOp cluster) -> WalkResult {
2176     // Ensure that cluster was formed for TFRT JIT compilation.
2177     auto policy = cluster->getAttr("policy").dyn_cast_or_null<StringAttr>();
2178     if (!policy || policy.getValue() != "tfrt.auto-fusion")
2179       return WalkResult::advance();
2180 
2181     // Get the maximum argument size of the parent function.
2182     mlir::func::FuncOp parent_func =
2183         cluster->getParentOfType<mlir::func::FuncOp>();
2184     int64_t max_arg_size = get_max_arg_size(parent_func);
2185 
2186     if (failed(OutlineClusterOp(cluster, max_arg_size, &symbol_table)))
2187       return WalkResult::interrupt();
2188     return WalkResult::advance();
2189   });
2190 
2191   if (result.wasInterrupted()) {
2192     module->emitError("Failed to outline tf_device.cluster operations");
2193     signalPassFailure();
2194   }
2195 }
2196 
2197 static std::unique_ptr<Pass> CreateOutlineJitRtClustersPass() {
2198   return std::make_unique<OutlineJitRtClustersPass>();
2199 }
2200 
2201 // -------------------------------------------------------------------------- //
2202 
2203 void CreateTFExecutorToTFPipeline(mlir::OpPassManager &pm,
2204                                   const TfrtPipelineOptions &options) {
2205   // Due to b/191304670, functionalized while ops might not have the
2206   // shape_invariant attribute set correctly, which leads to failure in shape
2207 // inference. As a workaround, we conservatively (e.g., we place fewer
2208 // restrictions on tf.While, which avoids failures but leads to potentially
2209   // less exact shape inference) set the shape_invariant attribute in all
2210   // tf.While ops before performing shape inference.
2211   //
2212   // Note that this pass might not work well with TF XLA bridge, but this is
2213   // fine as TF XLA bridge is run before this pipeline. For CPU ops, less exact
2214   // shape inference may lead to fewer optimizations but it should be fine as it
2215   // is limited to while ops currently.
2216   //
2217   // TODO(b/191304670): Remove this pass once the shape_invariant attribute is
2218   // set correctly in the upstream.
2219   pm.addNestedPass<mlir::func::FuncOp>(
2220       tfrt_compiler::CreateSetShapeInvariantInWhileOps());
2221 
2222   // We pass the MLIR module through the TF standard pipeline, which for
2223   // instance does shape inference, canonicalization, inlining, etc.
2224   pm.addNestedPass<mlir::func::FuncOp>(
2225       mlir::tf_executor::CreateTFExecutorGraphPruningPass());
2226   pm.addNestedPass<mlir::func::FuncOp>(
2227       mlir::tf_executor::CreateTFExecutorIslandCoarseningPass());
2228 
2229   AddTfDeviceAssignmentPasses(pm, options);
2230 
2231   // Here we perform TFRT-specific optimization before standard TF optimization,
2232   // as TFRT-specific optimization may create more opportunities.
2233   pm.addNestedPass<mlir::func::FuncOp>(
2234       tfrt_compiler::CreateOptimizeTfForTfrtPass());
2235   pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
2236   pm.addPass(mlir::TF::CreateTFShapeInferencePass());
2237   pm.addPass(mlir::createInlinerPass());
2238   pm.addPass(mlir::createSymbolDCEPass());
2239   pm.addNestedPass<mlir::func::FuncOp>(mlir::TF::CreateTFOptimizePass());
2240   pm.addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());
2241 
2242   AddTfDeviceAssignmentPasses(pm, options);
2243 
2244   // After the standard passes, we now have MLIR in the TF dialect, and we
2245   // convert reference variables to resource variables, which is best-effort.
2246   pm.addPass(CreateConvertReferenceVariableToResourceVariablePass());
2247 
2248   // Move the tf.Assert op to the end of the function, so that it does not
2249   // impose unnecessary control dependencies on other ops.
2250   pm.addPass(tfrt_compiler::CreateReorderTfAssertPass());
2251 
2252   // Optimize the side effects of control flow ops by examining the ops in
2253   // their callees.
2254   pm.addPass(tfrt_compiler::CreateOptimizeTfControlFlowSideEffectPass());
2255 
2256   // Remove tf.If ops' operands that are produced by tf.Const ops.
2257   pm.addPass(tfrt_compiler::CreateRemoveTfIfConstArgsPass());
2258 
2259   // Merge non-side-effecting tf.If ops if their operands are the same.
2260   pm.addPass(tfrt_compiler::CreateMergeTfIfOpsPass());
2261 
2262   // Deduplicate functions invoked by tf.BatchFunction with the same
2263   // shared_name.
2264   pm.addPass(
2265       tfrt_compiler::CreateDeduplicateFunctionsInovkedByBatchFunctionPass());
2266 
2267   // Apply standard optimization after optimizing control flow ops.
2268   pm.addPass(mlir::createInlinerPass());
2269   pm.addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());
2270 
2271   // TODO(b/187876545): An extra shape inference pass is added because it does
2272   // not work well with tf.Identity op that remove ref type. So we work around
2273   // by performing shape inference again after reference variable to resource
2274   // variable conversion. We should remove this after b/187876545 is fixed.
2275   pm.addPass(mlir::TF::CreateTFShapeInferencePass());
2276 
2277   pm.addNestedPass<mlir::func::FuncOp>(
2278       mlir::TFDevice::CreateLaunchToDeviceAttributePass());
2279 
2280   // After all standard passes run layout optimization to assign optimal data
2281   // format for all layout sensitive operations.
2282   mlir::TF::LayoutOptimizationPipelineOptions layout_optimization_options;
2283   layout_optimization_options.force_data_format =
2284       options.force_data_format.getValue();
2285   // TODO(b/191304261): Folding transpose in ops is buggy in the layout
2286   // optimization pass. Disable it to avoid errors in b/191304261. This should
2287   // not affect CPU performance as it does not change the number of ops, nor
2288   // does it change the types of the ops.
2289   layout_optimization_options.skip_fold_transpose_in_ops = true;
2290   mlir::TF::CreateLayoutOptimizationPipeline(pm.nest<mlir::func::FuncOp>(),
2291                                              layout_optimization_options);
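  // Note: force_data_format is typically "NHWC" or "NCHW" (the standard TF
  // data formats); the exact set of accepted values is defined by the layout
  // optimization pipeline itself, not by this file.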

  // Run the canonicalization pipeline to remove unused constants and bypassed
  // transpose operations left in the IR after layout optimization.
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());

  // Decompose resource ops, as resource variables will be converted to
  // tensors directly.
  if (options.decompose_resource_ops)
    pm.addNestedPass<mlir::func::FuncOp>(
        mlir::TFDevice::CreateDecomposeResourceOpsPass());

  AddTfDeviceAssignmentPasses(pm, options);

  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::TF::CreateTensorDeviceCopyConversionPass());

  // Outline auto-fusion clusters into tf_device.cluster operations and then
  // convert them to functions. We currently support only TFRT fallback tensors
  // as operands, so we disable these passes if we can have native ops after
  // lowering.
  if (!options.enable_native_ops) {
    pm.addNestedPass<mlir::func::FuncOp>(CreateTfJitRtClusteringPass(
        options.auto_fusion_oplist, options.auto_fusion_min_cluster_size));

    // Sink small constants into the outlined clusters to reduce the number of
    // arguments for each of the execute operations.
    auto is_compilable_const = [](mlir::tf_device::ClusterOp cluster,
                                  mlir::ElementsAttr value) -> bool {
      // Ensure that the cluster was formed for TFRT JIT compilation.
      auto policy = cluster->getAttr("policy").dyn_cast_or_null<StringAttr>();
      if (!policy || policy.getValue() != "tfrt.auto-fusion") return false;

      // Check that the TF->JitRt compiler supports constant compilation.
      return mlir::succeeded(IsCompilableConstant(value));
    };
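    // For illustration only (a hypothetical sketch of the IR, not produced by
    // this file), a cluster eligible for constant sinking carries the policy
    // attribute set by the clustering pass above:
    //   tf_device.cluster {...} {policy = "tfrt.auto-fusion"}
    // Clusters without that attribute, or with a different policy value, keep
    // their constants as cluster operands.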

    pm.addNestedPass<mlir::func::FuncOp>(
        mlir::TFDevice::CreateClusterConstantSinkingPass(is_compilable_const));

    // Outline the formed JIT-compiled device clusters into functions.
    pm.addPass(CreateOutlineJitRtClustersPass());
  }

  // Rewrite operation sequences into device-specific fusions.
  DeviceNameUtils::ParsedName parsed_name;

  // Parse errors are ignored in release builds; the assert below catches an
  // invalid default device in debug builds.
  bool success =
      DeviceNameUtils::ParseFullName(options.default_device, &parsed_name);
  assert(success && "default device is invalid");
  (void)success;
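  // For example, a fully specified default device such as
  // "/job:localhost/replica:0/task:0/device:GPU:0" parses with
  // parsed_name.type == "GPU", which enables the GPU op fusion pass below.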

  if (parsed_name.has_type && parsed_name.type == DEVICE_GPU)
    pm.addNestedPass<mlir::func::FuncOp>(mlir::TF::CreateGpuOpFusionPass());

  if (parsed_name.has_type && parsed_name.type == DEVICE_CPU)
    pm.addNestedPass<mlir::func::FuncOp>(
        mlir::TF::CreateFusedKernelMatcherPass());

  if (options.tpu_fuse_ops) {
    pm.addNestedPass<mlir::func::FuncOp>(
        tfrt_compiler::CreateFuseTpuCompileAndExecutePass());
    // Remove ops for the input to _TPUCompileMlirOp, which are no longer
    // needed after CreateFuseTpuCompileAndExecutePass.
    pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
  }

  AddTfDeviceAssignmentPasses(pm, options);

  pm.addPass(CreateLowerTFSavedModelPass(options.hoist_invariant_ops));
}

void CreateTfExecutorToTfrtPipelineHelper(mlir::OpPassManager &pm,
                                          const TfrtPipelineOptions &options) {
  CreateTFExecutorToTFPipeline(pm, options);

  pm.addPass(CreateTfToTfrtConversionPass(options));

  pm.addPass(CreateRemoveDeviceAttributePass());

  // Run optimizer on the MLIR module in CoreRT dialect.
  if (options.enable_optimizer) {
    pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
    pm.addPass(mlir::createInlinerPass());
    pm.addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());
    pm.addNestedPass<mlir::func::FuncOp>(
        tfrt_compiler::CreateInsertFallbackTensorCopyPass());
  }
}

// If verbose logging is on, dump the output of each pass to a file directory
// set via the env var TF_DUMP_GRAPH_PREFIX, e.g.:
// export TF_DUMP_GRAPH_PREFIX=/tmp/mlir
void CreateTfExecutorToTfrtPipeline(mlir::PassManager &pm,
                                    const TfrtPipelineOptions &options) {
  if (VLOG_IS_ON(1)) {
    // Print the whole module after each pass, which requires disabling
    // multi-threading as well.
    pm.getContext()->disableMultithreading();
    pm.enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>(
        /*print_module_scope=*/true));
  }
  CreateTfExecutorToTfrtPipelineHelper(pm, options);
}
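
// A minimal usage sketch (hypothetical caller code, not part of this file),
// assuming `module` is an mlir::ModuleOp with the required dialects loaded:
//
//   mlir::PassManager pm(module.getContext());
//   tensorflow::TfrtPipelineOptions options;
//   tensorflow::CreateTfExecutorToTfrtPipeline(pm, options);
//   if (mlir::failed(pm.run(module))) {
//     // Handle the conversion failure.
//   }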

static mlir::PassRegistration<TfToTfrtConversionPass> tf_to_tfrt_pass;

static mlir::PassRegistration<OutlineJitRtClustersPass>
    tf_outline_jitrt_cluster_pass(CreateOutlineJitRtClustersPass);

static mlir::PassPipelineRegistration<TfrtPipelineOptions> tf_pipeline(
    "tf-executor-to-tfrt-pipeline",
    "Convert TensorFlow Executor dialect to TFRT dialect and "
    "also apply necessary optimization passes.",
    CreateTfExecutorToTfrtPipelineHelper);
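
// With this registration, an mlir-opt-style driver that links in this file
// (for example, a tf-tfrt-opt style tool, assuming one is available in the
// build) can run the whole pipeline from the command line, e.g.:
//   <opt-tool> --tf-executor-to-tfrt-pipeline input.mlir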

}  // namespace tensorflow