1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements lowering of TF dialect to TFRT CoreRuntime ExecuteOp.
17 // This lowering pass is heavily experimental and incomplete. External code
18 // should not depend on the code here, and please do not treat it as an
19 // example of "the path forward" for this kind of lowering.
20 
21 #include <vector>
22 
23 #include "mlir/Dialect/StandardOps/IR/Ops.h"
24 #include "mlir/IR/Attributes.h"
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/Dialect.h"
27 #include "mlir/IR/OperationSupport.h"
28 #include "mlir/IR/Types.h"
29 #include "mlir/Pass/PassManager.h"
30 #include "mlir/Pass/PassOptions.h"
31 #include "mlir/Transforms/DialectConversion.h"
32 #include "mlir/Transforms/Passes.h"
33 #include "mlir/Transforms/RegionUtils.h"
34 #include "llvm/ADT/DenseSet.h"
35 #include "llvm/Support/FormatVariadic.h"
36 #include "llvm/Support/raw_ostream.h"
37 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
38 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
39 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
40 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
41 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
42 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
43 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
44 #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
45 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
46 #include "tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h"
47 #include "tensorflow/compiler/mlir/tfrt/analysis/tensor_array_side_effect_analysis.h"
48 #include "tensorflow/compiler/mlir/tfrt/jit/opdefs/tf_cpurt_ops.h"
49 #include "tensorflow/compiler/mlir/tfrt/jit/tf_cpurt_clustering.h"
50 #include "tensorflow/compiler/mlir/tfrt/jit/tf_cpurt_passes.h"
51 #include "tensorflow/compiler/mlir/tfrt/transforms/corert_converter.h"
52 #include "tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.h"
53 #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
54 #include "tensorflow/core/framework/tensor.h"
55 #include "tensorflow/core/framework/types.h"
56 #include "tensorflow/core/platform/tstring.h"
57 #include "tensorflow/core/runtime_fallback/opdefs/tfrt_fallback.h"
58 #include "tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.h"
59 #include "tfrt/cpu/jit/opdefs/cpurt_ops.h"  // from @tf_runtime
60 #include "tfrt/basic_kernels/opdefs/basic_kernels.h"  // from @tf_runtime
61 #include "tfrt/basic_kernels/opdefs/tfrt_base.h"  // from @tf_runtime
62 #include "tfrt/basic_kernels/opdefs/types.h"  // from @tf_runtime
63 #include "tfrt/core_runtime/opdefs/attributes.h"  // from @tf_runtime
64 #include "tfrt/core_runtime/opdefs/core_runtime.h"  // from @tf_runtime
65 #include "tfrt/core_runtime/opdefs/types.h"  // from @tf_runtime
66 #include "tfrt/distributed_runtime/opdefs/kernels.h"  // from @tf_runtime
67 #include "tfrt/distributed_runtime/opdefs/types.h"  // from @tf_runtime
68 #include "tfrt/test_kernels/opdefs/test_kernels.h"  // from @tf_runtime
69 
70 namespace tensorflow {
71 namespace {
72 
73 constexpr unsigned kFallbackBenefit = 1;
74 constexpr unsigned kCoreRTBenefit = 2;
75 constexpr char kGpuDeviceName[] =
76     "/job:localhost/replica:0/task:0/device:GPU:0";
77 constexpr char kCpuDeviceName[] =
78     "/job:localhost/replica:0/task:0/device:CPU:0";
79 constexpr char kTFDeviceAttr[] = "tf.device";
80 constexpr char kTFRTDeviceAttr[] = "tfrt.device";
81 constexpr char kDeviceAttr[] = "device";
82 constexpr char kHostAttr[] = "host";
83 constexpr char kDeviceTypeTpu[] = "TPU";
84 
85 void getDependentConversionDialects(mlir::DialectRegistry &registry) {
86   registry.insert<tfrt::corert::CoreRTDialect,
87                   tfrt::fallback_async::FallbackAsyncDialect,
88                   tfrt::compiler::TFRTDialect, tfrt::dist::DistributedDialect,
89                   tf_cpurt::CpuRuntimeDialect>();
90 }
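// Hypothetical usage sketch (not code from this file): a conversion pass that
// produces the dialects above would typically call this helper from its
// getDependentDialects() hook so the dialects are loaded before conversion.
// "ExampleTfToTfrtPass" is an illustrative name, not a class defined here.
//
//   void ExampleTfToTfrtPass::getDependentDialects(
//       mlir::DialectRegistry &registry) const {
//     getDependentConversionDialects(registry);
//   }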
91 
92 mlir::Value GetFunctionInputChain(mlir::Operation *op) {
93   auto func_op = op->getParentOfType<mlir::FuncOp>();
94   return func_op.getArgument(0);
95 }
96 
97 // Convert TF dialect ops to tfrt_fallback.executeop for non-side-effecting ops
98 // and tfrt_fallback.executeop.seq for side-effecting ops.
99 //
100 // For example,
101 //
102 // %0 = "tf.MatMul"(%arg0, %arg1) {device = "cpu"} :
103 //    (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
104 //
105 // is converted to
106 //
107 // %result = tfrt_fallback.executeop device("cpu")
108 //    "tf.MatMul"(%arg0, %arg1) : 1
109 //
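// For side-effecting ops, the conversion below instead produces a
// tfrt_fallback.executeop.seq that threads an explicit chain. A rough,
// illustrative sketch (the op name and exact textual form are placeholders,
// not verbatim IR):
//
// %out_chain, %result = tfrt_fallback.executeop.seq(%in_chain) device("cpu")
//    "tf.SomeStatefulOp"(%arg0) : 1
//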
110 class FallbackExecuteOpConversion : public mlir::ConversionPattern {
111  public:
112   FallbackExecuteOpConversion(
113       mlir::MLIRContext *context, CoreRTConverter *corert_converter,
114       tfrt_compiler::FallbackConverter *fallback_converter,
115       const tfrt_compiler::CostAnalysis *cost_analysis,
116       bool tpu_lower_to_fallback)
117       : mlir::ConversionPattern(mlir::Pattern::MatchAnyOpTypeTag(),
118                                 kFallbackBenefit, context),
119         corert_converter_(*corert_converter),
120         fallback_converter_(*fallback_converter),
121         cost_analysis_(*cost_analysis),
122         tpu_lower_to_fallback_(tpu_lower_to_fallback) {}
123 
124   LogicalResult matchAndRewrite(
125       mlir::Operation *op, ArrayRef<mlir::Value> operands,
126       mlir::ConversionPatternRewriter &rewriter) const override {
127     if (!UseFallback(op)) return failure();
128 
129     corert_converter_.MaterializeDerivedAttributes(op);
130 
131     mlir::StringAttr device = op->getAttrOfType<mlir::StringAttr>(kDeviceAttr);
132     if (!device || device.getValue().empty())
133       return op->emitWarning("failed to find a non-empty 'device' attribute");
134     op->removeAttr(kDeviceAttr);
135     auto parsed_device_name =
136         corert_converter_.ParseDeviceName(device.getValue());
137     if (!parsed_device_name)
138       return op->emitWarning("failed to parse the device name");
139 
140     // Convert the function (symbol) attributes to an array of string
141     // attributes, which represents the function names.
142     llvm::SmallVector<mlir::Identifier, 4> func_attr_keys;
143     mlir::ArrayAttr op_func_attrs =
144         corert_converter_.CreateOpFuncAttrs(op->getAttrs(), &func_attr_keys);
145 
146     // Remove the function attributes, which have already been processed.
147     for (const auto &key : func_attr_keys) op->removeAttr(key);
148 
149     mlir::ArrayAttr op_attrs = corert_converter_.CreateOpAttrs(op->getAttrs());
150     if (!op_attrs) return op->emitWarning("failed to lower attributes.");
151 
152     mlir::StringAttr op_name =
153         rewriter.getStringAttr(op->getName().getStringRef());
154 
155     // Currently the fallback executeop does not support the GPU device. For
156     // GPU ops, we still lower them to corert executeop, where more devices
157     // are supported.
158     //
159     // TODO(b/166705169): Once GPU devices are supported in fallback, we can
160     // remove the conversion to corert.
161     if (parsed_device_name->device_type == DEVICE_GPU ||
162         (parsed_device_name->device_type == kDeviceTypeTpu &&
163          !tpu_lower_to_fallback_)) {
164       return ConvertToCoreRTExecuteOp(
165           op, operands, parsed_device_name->op_handler_name, op_attrs,
166           op_func_attrs, op_name, rewriter);
167     }
168 
169     // For DEVICE_CPU, DEVICE_DEFAULT, and DEVICE_TPU_SYSTEM, we use fallback.
170     // Note that TPU bridge should handle all ops that are required to be
171     // executed on TPU. So if there are still ops that are placed on TPU at this
172     // stage of lowering TF to TFRT, then these ops are supposed to be executed
173     // on host.
174     return ConvertToFallbackExecuteOp(op, operands, device, op_attrs,
175                                       op_func_attrs, op_name,
176                                       fallback_converter_, rewriter);
177   }
178 
179  private:
180   // Return true if this op can be lowered to fallback ops.
181   bool UseFallback(mlir::Operation *op) const {
182     if (!llvm::isa<mlir::TF::TensorFlowDialect>(op->getDialect())) return false;
183 
184     // Below is a blocklist of ops that should not go through this conversion
185     // pattern.
186     // TODO(b/173017701): have a centralized place to hold the information
187     // whether a TF op should be lowered to FallbackExecute op.
188     return !llvm::isa<
189         mlir::TF::_TfrtSetResourceOp, mlir::TF::_TfrtGetResourceOp,
190         mlir::TF::_TPUCompileMlirOp, mlir::TF::TPUCompileSucceededAssertOp,
191         mlir::TF::TPUExecuteOp,
192         // Specifically handle control flow ops.
193         mlir::TF::CaseOp, mlir::TF::IfOp, mlir::TF::WhileOp,
194         mlir::TF::StatefulPartitionedCallOp, mlir::TF::PartitionedCallOp>(op);
195   }
196 
197   mlir::LogicalResult ConvertToFallbackExecuteOp(
198       mlir::Operation *op, llvm::ArrayRef<mlir::Value> operands,
199       mlir::StringAttr device, mlir::ArrayAttr op_attrs,
200       mlir::ArrayAttr op_func_attrs, mlir::StringAttr op_name,
201       tfrt_compiler::FallbackConverter &fallback_converter,
202       mlir::ConversionPatternRewriter &rewriter) const;
203 
204   mlir::LogicalResult ConvertToCoreRTExecuteOp(
205       mlir::Operation *op, llvm::ArrayRef<mlir::Value> operands,
206       llvm::StringRef op_handler_name, mlir::ArrayAttr op_attrs,
207       mlir::ArrayAttr op_func_attrs, mlir::StringAttr op_name,
208       mlir::ConversionPatternRewriter &rewriter) const;
209 
210   CoreRTConverter &corert_converter_;
211   tfrt_compiler::FallbackConverter &fallback_converter_;
212   const tfrt_compiler::CostAnalysis &cost_analysis_;
213   bool tpu_lower_to_fallback_;
214 };
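// Hypothetical registration sketch (the pass's real pattern registration lives
// elsewhere and may differ): the pattern is constructed with the converters and
// cost analysis shown above, assuming a pattern list named "patterns" and
// already-constructed converter/analysis objects.
//
//   patterns.insert<FallbackExecuteOpConversion>(
//       context, &corert_converter, &fallback_converter, &cost_analysis,
//       /*tpu_lower_to_fallback=*/true);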
215 
216 mlir::LogicalResult FallbackExecuteOpConversion::ConvertToFallbackExecuteOp(
217     mlir::Operation *op, llvm::ArrayRef<mlir::Value> operands,
218     mlir::StringAttr device, mlir::ArrayAttr op_attrs,
219     mlir::ArrayAttr op_func_attrs, mlir::StringAttr op_name,
220     tfrt_compiler::FallbackConverter &fallback_converter,
221     mlir::ConversionPatternRewriter &rewriter) const {
222   llvm::SmallVector<Type, 4> result_types(
223       op->getNumResults(), rewriter.getType<tfrt::fallback::TFTensorType>());
224 
225   // Convert the operands to tfrt_fallback.tf_tensor if needed.
226   llvm::SmallVector<mlir::Value, 4> new_operands;
227   if (mlir::failed(tfrt_compiler::ConvertFallbackOperands(
228           op, device.getValue(), operands, &new_operands, rewriter)))
229     return failure();
230 
231   auto fallback_key =
232       rewriter.getI64IntegerAttr(fallback_converter.GetNextFallbackKey());
233 
234   // Query cost analysis to assign costs.
235   auto cost = rewriter.getI64IntegerAttr(cost_analysis_.GetCost(op));
236 
237   if (mlir::MemoryEffectOpInterface::hasNoEffect(op)) {
238     auto new_op = rewriter.create<tfrt::fallback_async::ExecuteOp>(
239         op->getLoc(), result_types, new_operands, device, op_attrs,
240         op_func_attrs, fallback_key, op_name, cost);
241     fallback_converter.RegisterFallbackOp(new_op);
242     rewriter.replaceOp(op, new_op.results());
243   } else {
244     auto in_chain = corert_converter_.GetLocalSideEffectChain(op, &rewriter);
245     auto out_chain = in_chain;
246 
247     if (tfrt_compiler::IsTensorArrayOp(op)) {
248       // If it is a tensor array op, we don't need to use
249       // tfrt_fallback_async.executeop.seq because its operands/results already
250       // take care of control dependencies.
251       auto new_op = rewriter.create<tfrt::fallback_async::ExecuteOp>(
252           op->getLoc(), result_types, new_operands, device, op_attrs,
253           op_func_attrs, fallback_key, op_name, cost);
254       fallback_converter.RegisterFallbackOp(new_op);
255       rewriter.replaceOp(op, new_op.results());
256     } else {
257       // Create tfrt_fallback.executeop.seq if it is a side-effecting op.
258       auto new_op = rewriter.create<tfrt::fallback_async::ExecuteOpSeq>(
259           op->getLoc(), corert_converter_.chain_type(), result_types, in_chain,
260           new_operands, device, op_attrs, op_func_attrs, fallback_key, op_name,
261           cost);
262       fallback_converter.RegisterFallbackOp(new_op);
263       rewriter.replaceOp(op, new_op.results());
264       out_chain = new_op.out_op_chain();
265     }
266 
267     // Register the converted op so that it can be retrieved by successors.
268     corert_converter_.RegisterLocalSideEffectChain(op, out_chain);
269   }
270 
271   return success();
272 }
273 
274 mlir::LogicalResult FallbackExecuteOpConversion::ConvertToCoreRTExecuteOp(
275     mlir::Operation *op, llvm::ArrayRef<mlir::Value> operands,
276     llvm::StringRef op_handler_name, mlir::ArrayAttr op_attrs,
277     mlir::ArrayAttr op_func_attrs, mlir::StringAttr op_name,
278     mlir::ConversionPatternRewriter &rewriter) const {
279   llvm::SmallVector<Type, 4> result_types(
280       op->getNumResults(), rewriter.getType<tfrt::corert::TensorHandleType>());
281 
282   // Convert the operands to tensorhandles if needed.
283   llvm::SmallVector<mlir::Value, 4> new_operands;
284   if (mlir::failed(tfrt_compiler::ConvertCoreRTOperands(
285           op, operands, &new_operands, rewriter)))
286     return failure();
287 
288   // Get the op handler, or create one if it does not exist yet. Note that
289   // ConvertOpHandler changes internal state, so it can only be called if the
290   // rewrite is guaranteed to succeed afterwards.
291   auto op_handler =
292       corert_converter_.ConvertOpHandler(op, op_handler_name, &rewriter);
293   if (!op_handler) return failure();
294 
295   if (mlir::MemoryEffectOpInterface::hasNoEffect(op)) {
296     auto new_op = rewriter.create<tfrt::corert::ExecuteOp>(
297         op->getLoc(), result_types, op_handler, new_operands, op_attrs,
298         op_func_attrs, op_name);
299     rewriter.replaceOp(op, new_op.results());
300   } else {
301     // Create corert.executeop.seq if it is a side-effecting op.
302     auto new_op = rewriter.create<tfrt::corert::ExecuteOpSeq>(
303         op->getLoc(), corert_converter_.chain_type(), result_types, op_handler,
304         corert_converter_.GetLocalSideEffectChain(op, &rewriter), new_operands,
305         op_attrs, op_func_attrs, op_name);
306     rewriter.replaceOp(op, new_op.results());
307 
308     // Register the converted op so that it can be retrieved by successors.
309     corert_converter_.RegisterLocalSideEffectChain(op, new_op.out_op_chain());
310   }
311 
312   return success();
313 }
314 
315 class FallbackConstOpConversion
316     : public mlir::OpConversionPattern<mlir::TF::ConstOp> {
317  public:
318   FallbackConstOpConversion(mlir::MLIRContext *context,
319                             CoreRTConverter *corert_converter)
320       : mlir::OpConversionPattern<mlir::TF::ConstOp>(context),
321         corert_converter_(*corert_converter) {}
322 
323   mlir::LogicalResult matchAndRewrite(
324       mlir::TF::ConstOp op, llvm::ArrayRef<mlir::Value> operands,
325       mlir::ConversionPatternRewriter &rewriter) const override {
326     // Some data types are handled separately using a fast path.
327     if (corert_converter_.IsSupportedNumericDType(op.dtype()) ||
328         op.dtype().isa<mlir::TF::StringType>())
329       return failure();
330 
331     // For other data types that do not have a fast path (e.g. quantized
332     // types), we convert the value to a serialized tensor proto.
333 
334     tensorflow::TensorProto tensor_proto;
335     auto status = ConvertToTensorProto(op.value(), &tensor_proto);
336     if (!status.ok()) return op.emitError(status.error_message());
337 
338     rewriter.replaceOpWithNewOp<tfrt::fallback_async::ConstTensorProtoOp>(
339         op, rewriter.getType<tfrt::fallback::TFTensorType>(),
340         tensor_proto.SerializeAsString());
341 
342     return success();
343   }
344 
345  private:
346   CoreRTConverter &corert_converter_;
347 };
348 
349 class FallbackSetResourceOp
350     : public mlir::OpConversionPattern<mlir::TF::_TfrtSetResourceOp> {
351  public:
352   FallbackSetResourceOp(mlir::MLIRContext *context,
353                         CoreRTConverter *corert_converter)
354       : mlir::OpConversionPattern<mlir::TF::_TfrtSetResourceOp>(context),
355         corert_converter_(*corert_converter) {}
356 
357   mlir::LogicalResult matchAndRewrite(
358       mlir::TF::_TfrtSetResourceOp op, llvm::ArrayRef<mlir::Value> operands,
359       mlir::ConversionPatternRewriter &rewriter) const override {
360     mlir::StringAttr device = op->getAttrOfType<mlir::StringAttr>(kDeviceAttr);
361     if (!device || device.getValue().empty())
362       return op->emitWarning("failed to find a non-empty 'device' attribute");
363 
364     // Currently the static resource is always on host CPU.
365     //
366     // TODO(chky): Support resource on other devices.
367     llvm::SmallVector<mlir::Value, 4> new_operands;
368     if (mlir::failed(tfrt_compiler::ConvertFallbackOperands(
369             op, kCpuDeviceName, operands, &new_operands, rewriter)))
370       return failure();
371 
372     assert(new_operands.size() == 1);
373 
374     auto new_op = rewriter.create<tfrt::fallback_async::SetResourceOp>(
375         op.getLoc(), corert_converter_.chain_type(),
376         corert_converter_.GetLocalSideEffectChain(op, &rewriter),
377         new_operands[0], device.getValue(), op.index());
378 
379     // Register the converted op so that it can be retrieved by successors.
380     corert_converter_.RegisterLocalSideEffectChain(op, new_op.out_ch());
381 
382     rewriter.eraseOp(op);
383 
384     return success();
385   }
386 
387  private:
388   CoreRTConverter &corert_converter_;
389 };
390 
391 class FallbackGetResourceOp
392     : public mlir::OpConversionPattern<mlir::TF::_TfrtGetResourceOp> {
393  public:
394   FallbackGetResourceOp(mlir::MLIRContext *context,
395                         CoreRTConverter *corert_converter)
396       : mlir::OpConversionPattern<mlir::TF::_TfrtGetResourceOp>(context),
397         corert_converter_(*corert_converter) {}
398 
399   mlir::LogicalResult matchAndRewrite(
400       mlir::TF::_TfrtGetResourceOp op, llvm::ArrayRef<mlir::Value> operands,
401       mlir::ConversionPatternRewriter &rewriter) const override {
402     mlir::StringAttr device = op->getAttrOfType<mlir::StringAttr>(kDeviceAttr);
403     if (!device || device.getValue().empty())
404       return op->emitWarning("failed to find a non-empty 'device' attribute");
405 
406     llvm::SmallVector<mlir::Type, 4> result_types(
407         op.getNumResults(), rewriter.getType<tfrt::fallback::TFTensorType>());
408 
409     auto new_op = rewriter.create<tfrt::fallback_async::GetResourceOp>(
410         op.getLoc(), corert_converter_.chain_type(), result_types,
411         corert_converter_.GetLocalSideEffectChain(op, &rewriter),
412         device.getValue(), op.indices());
413 
414     // Register the converted op so that it can be retrieved by successors.
415     corert_converter_.RegisterLocalSideEffectChain(op, new_op.out_ch());
416 
417     rewriter.replaceOp(op, new_op.results());
418 
419     return success();
420   }
421 
422  private:
423   CoreRTConverter &corert_converter_;
424 };
425 
426 // Convert a tf_device.remote_run op to a tfrt_dist.remote_execute_func op.
427 //
428 // For example,
429 //
430 // %result = tf_device.remote_run "/job:worker/replica:0/task:0"
431 //   @remote_func(%arg0) : (tensor<i32>) -> (tensor<i32>)
432 //
433 // is converted to
434 //
435 // %0 = tfrt_dist.get_distributed_context
436 // %1 = tfrt_dist.get_task_handle %0
437 //  {task_name = "/job:worker/replica:0/task:0"}
438 // %2 = tfrt_dist.test_create_remote_chain_manager %0
439 // %3 = tfrt_dist.get_chain_for_task_handle %in_chain, %2, %1
440 // %out_op_chain, %results:2 = tfrt_dist.remote_execute_func[%in_chain, %0, %1]
441 //  @remote_func(%3, %arg1): (!tfrt_dist.remote_object_id, !corert.tensorhandle)
442 //  -> (!tfrt_dist.remote_object_id, !corert.tensorhandle)
443 // %4 = tfrt_dist.set_chain_for_task_handle %out_op_chain, %2, %1, %results#0
444 //
445 class TFDeviceRemoteRunOpConversion
446     : public mlir::OpConversionPattern<tf_device::RemoteRunOp> {
447  public:
448   TFDeviceRemoteRunOpConversion(mlir::MLIRContext *context,
449                                 mlir::TypeConverter *type_converter,
450                                 CoreRTConverter *corert_converter)
451       : mlir::OpConversionPattern<tf_device::RemoteRunOp>(context,
452                                                           kCoreRTBenefit),
453         type_converter_(*type_converter),
454         corert_converter_(*corert_converter) {}
455 
456   LogicalResult matchAndRewrite(
457       tf_device::RemoteRunOp op, ArrayRef<mlir::Value> operands,
458       ConversionPatternRewriter &rewriter) const override {
459     mlir::Value distributed_context =
460         corert_converter_.GetDistributedContext(op.getOperation(), &rewriter);
461     mlir::Value in_op_chain =
462         corert_converter_.GetLocalSideEffectChain(op, &rewriter);
463     mlir::Value task_handle = corert_converter_.GetTaskHandle(
464         op.getOperation(), op.host(), &rewriter);
465     mlir::Value remote_chain_mgr =
466         corert_converter_.GetRemoteChainManager(op, &rewriter);
467     mlir::Type remote_obj_id_ty =
468         rewriter.getType<tfrt::dist::RemoteObjectIdType>();
469     ModuleOp module = op->getParentOfType<ModuleOp>();
470     SymbolTable symtab(module);
471     FuncOp callee = symtab.lookup<FuncOp>(op.callee());
472     if (!callee) {
473       op.emitOpError("callee function ") << op.callee() << " is not found";
474       return failure();
475     }
476     StringAttr host = callee->getAttrOfType<StringAttr>(kHostAttr);
477     if (!host) {
478       op.emitOpError("callee function ")
479           << op.callee() << " should have the host attribute";
480       return failure();
481     }
482 
483     llvm::SmallVector<mlir::Value, 4> arguments;
484     // The first argument of the remote function should be a remote chain which
485     // is added to the function signature when it is lowered from TF dialect to
486     // TFRT dialect.
487     arguments.push_back(corert_converter_.GetRemoteSideEffectChain(
488         op, host.getValue(), &rewriter));
489     for (mlir::Value argument : op.callee_args()) {
490       arguments.push_back(argument);
491     }
492 
493     llvm::SmallVector<mlir::Type, 4> result_types;
494     // The first result of the remote function should be a remote chain which
495     // is added to the function signature when it is lowered from TF dialect to
496     // TFRT dialect.
497     result_types.push_back(remote_obj_id_ty);
498     for (mlir::Type type : op.getResultTypes()) {
499       (void)type_converter_.convertType(type, result_types);
500     }
501     auto remote_execute_func_op =
502         rewriter.create<tfrt::dist::RemoteExecuteFuncOp>(
503             op.getLoc(), corert_converter_.chain_type(), result_types,
504             in_op_chain, distributed_context, task_handle, op.callee(),
505             arguments);
506     rewriter.replaceOp(op, remote_execute_func_op.results().drop_front(1));
507 
508     auto set_chain_op = rewriter.create<tfrt::dist::SetChainForTaskHandleOp>(
509         op.getLoc(), corert_converter_.chain_type(),
510         remote_execute_func_op.out_op_chain(), remote_chain_mgr, task_handle,
511         remote_execute_func_op.results().front());
512     corert_converter_.RegisterLocalSideEffectChain(op,
513                                                    set_chain_op.out_op_chain());
514 
515     return success();
516   }
517 
518  private:
519   mlir::TypeConverter &type_converter_;
520   CoreRTConverter &corert_converter_;
521 };
522 
523 // Lowers a tf.BatchFunction to tfrt_fallback.batch_function.
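// Illustrative sketch of the lowering (attributes elided; op names and textual
// form are approximate, not verbatim IR):
//
//   %0 = "tf.BatchFunction"(%arg) {f = @batched_fn, ...}
//       : (tensor<?xi32>) -> tensor<?xi32>
//
// becomes roughly
//
//   %out_chain, %0 = tfrt_fallback.batch_function(%in_chain) @batched_fn
//       (%arg) {...}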
524 class FallbackBatchFunctionOpConversion
525     : public mlir::OpConversionPattern<mlir::TF::BatchFunctionOp> {
526  public:
527   FallbackBatchFunctionOpConversion(mlir::MLIRContext *context,
528                                     CoreRTConverter *corert_converter)
529       : mlir::OpConversionPattern<mlir::TF::BatchFunctionOp>(context,
530                                                              kFallbackBenefit),
531         corert_converter_(*corert_converter) {}
532 
533   LogicalResult matchAndRewrite(
534       mlir::TF::BatchFunctionOp op, ArrayRef<mlir::Value> operands,
535       ConversionPatternRewriter &rewriter) const override {
536     corert_converter_.MaterializeDerivedAttributes(op);
537 
538     // Remove the device attribute for fallback, as currently fallback will
539     // select device automatically.
540     //
541     // TODO(chky): The device attribute should be passed explicitly. This can be
542     // done once we change the kernel implementation to choose the device
543     // based on attributes.
544     op->removeAttr(rewriter.getIdentifier(kDeviceAttr));
545 
546     SymbolRefAttr f = op.fAttr();
547 
548     llvm::SmallVector<NamedAttribute, 12> attr_array;
549     for (auto &key_and_value : op->getAttrs()) {
550       StringRef name = key_and_value.first;
551       if (name == "_output_shapes" || name == "f") {
552         continue;
553       }
554       attr_array.push_back(key_and_value);
555     }
556     ArrayAttr op_attrs = corert_converter_.CreateOpAttrs(attr_array);
557     if (!op_attrs) return op.emitWarning("failed to lower attributes.");
558 
559     llvm::SmallVector<Type, 4> result_types;
560     for (auto type : op.getResultTypes()) {
561       if (failed(corert_converter_.convertType(type, result_types)))
562         return failure();
563     }
564 
565     llvm::SmallVector<mlir::Value, 4> new_operands;
566     if (mlir::failed(tfrt_compiler::ConvertCoreRTOperands(
567             op, operands, &new_operands, rewriter)))
568       return failure();
569 
570     auto new_op = rewriter.create<tfrt::fallback_async::BatchFunctionOp>(
571         op.getLoc(), corert_converter_.chain_type(), result_types,
572         corert_converter_.GetLocalSideEffectChain(op, &rewriter), new_operands,
573         f, op_attrs);
574     rewriter.replaceOp(op, new_op.results());
575 
576     // Register the converted op so that it can be retrieved by successors.
577     corert_converter_.RegisterLocalSideEffectChain(op, new_op.out_op_chain());
578 
579     return success();
580   }
581 
582  private:
583   CoreRTConverter &corert_converter_;
584 };
585 
586 // Lower a tf.Const op that creates a dense tensor with a supported numeric
587 // dtype to a native corert const dense tensor op (tfrt::corert::ConstDenseTensorOp).
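// Illustrative sketch (textual form approximate, not verbatim IR):
//
//   %0 = "tf.Const"() {value = dense<1.0> : tensor<2xf32>}
//       : () -> tensor<2xf32>
//
// becomes roughly
//
//   %0 = corert.const_dense_tensor dense<1.0> : tensor<2xf32>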
588 class CoreRTConstDenseTensorOpConversion
589     : public mlir::OpConversionPattern<mlir::TF::ConstOp> {
590  public:
591   CoreRTConstDenseTensorOpConversion(mlir::MLIRContext *context,
592                                      CoreRTConverter *corert_converter)
593       : mlir::OpConversionPattern<mlir::TF::ConstOp>(context, kCoreRTBenefit),
594         corert_converter_(*corert_converter) {}
595 
596   LogicalResult matchAndRewrite(
597       mlir::TF::ConstOp op, ArrayRef<mlir::Value> operands,
598       ConversionPatternRewriter &rewriter) const override {
599     if (!corert_converter_.IsSupportedNumericDType(op.dtype()))
600       return failure();
601 
602     // Only CPU ops can be lowered using this conversion. If there is no device
603     // assignment, this op is treated as a CPU op and can be lowered.
604     if (auto parsed_device_name = corert_converter_.ParseDeviceName(op))
605       if (parsed_device_name->device_type != DEVICE_CPU) return failure();
606 
607     auto new_op = rewriter.create<tfrt::corert::ConstDenseTensorOp>(
608         op.getLoc(), corert_converter_.tensor_handle_type(),
609         op.value().cast<DenseElementsAttr>());
610     rewriter.replaceOp(op, new_op->getResult(0));
611     return success();
612   }
613 
614  private:
615   CoreRTConverter &corert_converter_;
616 };
617 
618 // Convert the FuncOp with the following changes to meet TFRT's requirements:
619 // 1) Convert types for the arguments and the results.
620 // 2) Add a chain to the arguments.
621 // 3) Add a chain to the results for side-effects.
622 // 4) If any argument has a tf.device attribute, change the attribute name
623 //    to tfrt.device.
624 // 5) If any result has a tf.device attribute, change the attribute name
625 //    to tfrt.device.
626 //
627 // The input chain is used to signal visibility of all side-effects before
628 // calling this function. The output chain is used to signal visibility of all
629 // side-effects of this function.
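// Illustrative sketch (exact tensor types depend on the configured type
// converter, e.g. !corert.tensorhandle or !tfrt_fallback.tf_tensor; textual
// form approximate, not verbatim IR):
//
//   func @foo(%arg: tensor<i32>) -> tensor<i32>
//
// becomes roughly
//
//   func @foo(%chain: !tfrt.chain, %arg: !corert.tensorhandle)
//       -> (!tfrt.chain, !corert.tensorhandle)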
630 class TFRTFuncOpSignatureConversion
631     : public mlir::OpConversionPattern<mlir::FuncOp> {
632  public:
633   TFRTFuncOpSignatureConversion(mlir::MLIRContext *ctx,
634                                 mlir::TypeConverter *type_converter)
635       : OpConversionPattern(ctx), type_converter_(*type_converter) {}
636 
637   LogicalResult matchAndRewrite(
638       mlir::FuncOp func_op, llvm::ArrayRef<Value> operands,
639       ConversionPatternRewriter &rewriter) const override {
640     mlir::FunctionType type = func_op.getType();
641 
642     // Convert the original function arguments.
643     mlir::TypeConverter::SignatureConversion converted_signature(
644         type.getNumInputs());
645     // Add a chain as the first input.
646     converted_signature.addInputs(
647         rewriter.getType<tfrt::compiler::ChainType>());
648 
649     // Convert the original function results.
650     SmallVector<Type, 1> converted_results;
651     // Add a chain as the first result.
652     converted_results.push_back(rewriter.getType<tfrt::compiler::ChainType>());
653 
654     if (failed(type_converter_.convertSignatureArgs(type.getInputs(),
655                                                     converted_signature)) ||
656         failed(type_converter_.convertTypes(type.getResults(),
657                                             converted_results)) ||
658         failed(rewriter.convertRegionTypes(&func_op.getBody(), type_converter_,
659                                            &converted_signature))) {
660       return failure();
661     }
662 
663     llvm::SmallVector<mlir::DictionaryAttr, 4> arg_attrs;
664     // The first input, which is a chain added by this pass, has no attribute.
665     arg_attrs.emplace_back();
666     func_op.getAllArgAttrs(arg_attrs);
667     // If any argument has a tf.device attribute, change the attribute name to
668     // tfrt.device.
669     for (auto &arg_attr : arg_attrs) {
670       mlir::NamedAttrList arg_attr_list(arg_attr);
671       if (Attribute device = arg_attr_list.erase(kTFDeviceAttr)) {
672         arg_attr_list.set(kTFRTDeviceAttr, device);
673         arg_attr = arg_attr_list.getDictionary(device.getContext());
674       }
675     }
676     arg_attrs.resize(converted_signature.getConvertedTypes().size());
677 
678     // The first result, which is a chain added by this pass, has no attribute.
679     llvm::SmallVector<mlir::DictionaryAttr, 4> result_attrs;
680     result_attrs.emplace_back();
681     func_op.getAllResultAttrs(result_attrs);
682     // If any result has a tf.device attribute, change the attribute name to
683     // tfrt.device.
684     for (auto &result_attr : result_attrs) {
685       mlir::NamedAttrList result_attr_list(result_attr);
686       if (Attribute device = result_attr_list.erase(kTFDeviceAttr)) {
687         result_attr_list.set(kTFRTDeviceAttr, device);
688         result_attr = result_attr_list.getDictionary(device.getContext());
689       }
690     }
691     result_attrs.resize(converted_results.size());
692 
693     // Update the function signature in-place.
694     rewriter.updateRootInPlace(func_op, [&] {
695       func_op.setType(mlir::FunctionType::get(
696           func_op.getContext(), converted_signature.getConvertedTypes(),
697           converted_results));
698       func_op.setAllArgAttrs(arg_attrs);
699       func_op.setAllResultAttrs(result_attrs);
700     });
701 
702     return success();
703   }
704 
705  private:
706   mlir::TypeConverter &type_converter_;
707 };
708 
709 // Lower a tf.Const op that creates a string tensor to a native
710 // corert.create_string_tensor op.
711 class CoreRTConstStringTensorOpConversion
712     : public mlir::OpConversionPattern<mlir::TF::ConstOp> {
713  public:
714   CoreRTConstStringTensorOpConversion(mlir::MLIRContext *context,
715                                       CoreRTConverter *corert_converter)
716       : mlir::OpConversionPattern<mlir::TF::ConstOp>(context, kCoreRTBenefit),
717         corert_converter_(*corert_converter) {}
718 
719   LogicalResult matchAndRewrite(
720       mlir::TF::ConstOp op, ArrayRef<mlir::Value> operands,
721       ConversionPatternRewriter &rewriter) const override {  // NOLINT
722     if (!op.dtype().isa<mlir::TF::StringType>()) return failure();
723 
724     DenseStringElementsAttr attr = op.value().cast<DenseStringElementsAttr>();
725 
726     llvm::SmallVector<Attribute, 4> values;
727     values.reserve(attr.getNumElements());
728     for (const auto &element : attr.getRawStringData())
729       values.push_back(rewriter.getStringAttr(
730           llvm::StringRef(element.data(), element.size())));
731 
732     // Create the shape attribute from the tensor shape.
733     ArrayRef<int64_t> shape = op.value().getType().getShape();
734     llvm::SmallVector<mlir::Attribute, 4> dims;
735     dims.reserve(shape.size());
736     auto i64_type = rewriter.getIntegerType(64);
737     for (auto dim : shape)
738       dims.push_back(rewriter.getIntegerAttr(i64_type, dim));
739 
740     auto new_op = rewriter.create<tfrt::corert::ConstStringTensorOp>(
741         op.getLoc(), corert_converter_.tensor_handle_type(),
742         rewriter.getArrayAttr(dims), rewriter.getArrayAttr(values));
743 
744     rewriter.replaceOp(op, new_op.result());
745 
746     return success();
747   }
748 
749  private:
750   CoreRTConverter &corert_converter_;
751 };
752 
753 // Convert TF dialect operations with no side effects to CoreRT ExecuteOp. For
754 // example,
755 //
756 // %0 = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} :
757 //    (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
758 //
759 // is converted to
760 //
761 // %result = corert.executeop(%device)
762 //    "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} :
763 //    (!corert.tensorhandle, !corert.tensorhandle) -> !corert.tensorhandle
764 //
765 // Note that it will fail to match if some attributes are not supported.
766 template <typename TF_Op>
767 class CoreRTExecuteOpConversion : public mlir::OpConversionPattern<TF_Op> {
768  public:
769   CoreRTExecuteOpConversion(mlir::MLIRContext *context,
770                             CoreRTConverter *corert_converter)
771       : CoreRTExecuteOpConversion(context, corert_converter, "") {}
772 
773   // If device_name is not empty, only ops assigned to that device are lowered
774   // using CoreRTExecuteOpConversion.
775   CoreRTExecuteOpConversion(mlir::MLIRContext *context,
776                             CoreRTConverter *corert_converter,
777                             llvm::StringRef device_name)
778       : mlir::OpConversionPattern<TF_Op>(context, kCoreRTBenefit),
779         corert_converter_(*corert_converter),
780         device_name_(device_name) {}
781 
782   LogicalResult matchAndRewrite(
783       TF_Op op, ArrayRef<mlir::Value> operands,
784       ConversionPatternRewriter &rewriter) const override {
785     auto parsed_device_name = corert_converter_.ParseDeviceName(op);
786     // Return failure and emit warning if there is no device assignment.
787     if (!parsed_device_name) {
788       return op->emitWarning(
789           "failed to retrieve valid device when converting to "
790           "corert.executeop");
791     }
792 
793     // If device_name is specified, check the device of this op first.
794     if (!device_name_.empty()) {
795       // Skip if it does not match the specified device.
796       if (parsed_device_name->device_name != device_name_) return failure();
797     }
798 
799     mlir::StringAttr op_name = rewriter.getStringAttr(op.getOperationName());
800 
801     llvm::SmallVector<Type, 4> result_types;
802     for (auto type : op.getOperation()->getResultTypes()) {
803       if (failed(corert_converter_.convertType(type, result_types)))
804         return failure();
805     }
806 
807     corert_converter_.MaterializeDerivedAttributes(op);
808 
809     // Convert the function (symbol) attributes to an array of string
810     // attributes, which represents the function names.
811     llvm::SmallVector<mlir::Identifier, 4> func_attr_keys;
812     ArrayAttr op_func_attrs =
813         corert_converter_.CreateOpFuncAttrs(op->getAttrs(), &func_attr_keys);
814 
815     // Remove the function attributes, which have already been processed.
816     for (const auto &key : func_attr_keys) op->removeAttr(key);
817 
818     ArrayAttr op_attrs = corert_converter_.CreateOpAttrs(op->getAttrs());
819     if (!op_attrs) return op.emitError("failed to lower attributes.");
820 
821     llvm::SmallVector<mlir::Value, 4> new_operands;
822     if (mlir::failed(tfrt_compiler::ConvertCoreRTOperands(
823             op, operands, &new_operands, rewriter)))
824       return failure();
825 
826     // Get the op handler, or create one if it does not exist yet. Note that
827     // ConvertOpHandler changes internal state, so it can only be called if the
828     // rewrite is guaranteed to succeed afterwards.
829     auto op_handler = corert_converter_.ConvertOpHandler(
830         op, parsed_device_name->op_handler_name, &rewriter);
831     if (!op_handler) return failure();
832 
833     auto new_op = rewriter.create<tfrt::corert::ExecuteOp>(
834         op.getLoc(), result_types, op_handler, new_operands, op_attrs,
835         op_func_attrs, op_name);
836 
837     rewriter.replaceOp(op, new_op.results());
838     return success();
839   }
840 
841  private:
842   CoreRTConverter &corert_converter_;
843   llvm::StringRef device_name_;
844 };
845 
846 LogicalResult ConvertFunctionCallOperands(
847     mlir::Operation *op, llvm::ArrayRef<mlir::Value> operands,
848     llvm::SmallVectorImpl<mlir::Value> *new_operands,
849     mlir::ConversionPatternRewriter &rewriter, bool func_use_fallback_tensor) {
850   if (func_use_fallback_tensor) {
851     // TODO(b/182232457): Support other devices.
852     return tfrt_compiler::ConvertFallbackOperands(op, kCpuDeviceName, operands,
853                                                   new_operands, rewriter);
854   } else {
855     return tfrt_compiler::ConvertCoreRTOperands(op, operands, new_operands,
856                                                 rewriter);
857   }
858 }
859 
860 // Convert TF call ops (e.g. StatefulPartitionedCall) to tfrt.call.
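// Illustrative sketch (types and textual form approximate, not verbatim IR):
//
//   %0 = "tf.StatefulPartitionedCall"(%arg) {f = @callee}
//       : (tensor<i32>) -> tensor<i32>
//
// becomes roughly
//
//   %out_chain, %0 = tfrt.call @callee(%in_chain, %arg)
//       : (!tfrt.chain, !corert.tensorhandle)
//       -> (!tfrt.chain, !corert.tensorhandle)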
861 template <typename CallOp>
862 class TFRTCallOpConversion : public mlir::OpConversionPattern<CallOp> {
863  public:
864   TFRTCallOpConversion(mlir::MLIRContext *context,
865                        mlir::TypeConverter *type_converter,
866                        CoreRTConverter *corert_converter,
867                        bool func_use_fallback_tensor)
868       : mlir::OpConversionPattern<CallOp>(context),
869         type_converter_(*type_converter),
870         corert_converter_(*corert_converter),
871         func_use_fallback_tensor_(func_use_fallback_tensor) {}
872 
873   LogicalResult matchAndRewrite(
874       CallOp op, ArrayRef<mlir::Value> operands,
875       ConversionPatternRewriter &rewriter) const override {
876     auto callee =
877         op.getCallableForCallee().template dyn_cast<mlir::SymbolRefAttr>();
878     if (!callee) return failure();
879 
880     llvm::SmallVector<mlir::Value, 4> new_operands;
881     new_operands.push_back(
882         corert_converter_.GetLocalSideEffectChain(op, &rewriter));
883     // Currently the converted functions always use !corert.tensorhandle types
884     // for tensor arguments and results.
885     //
886     // TODO(b/175881042): We should avoid the tensor conversion here if the
887     // operand is !tfrt_fallback.tf_tensor, and it is also used as fallback
888     // tensor inside the callee function.
889     if (mlir::failed(ConvertFunctionCallOperands(
890             op, operands, &new_operands, rewriter, func_use_fallback_tensor_)))
891       return failure();
892 
893     llvm::SmallVector<mlir::Type, 4> result_types;
894     result_types.push_back(rewriter.getType<tfrt::compiler::ChainType>());
895     for (auto type : op.getOperation()->getResultTypes()) {
896       if (failed(type_converter_.convertType(type, result_types)))
897         return failure();
898     }
899 
900     auto new_op = rewriter.create<tfrt::compiler::CallOp>(
901         op.getLoc(), result_types, callee.getRootReference(), new_operands);
902     rewriter.replaceOp(op, new_op.getResults().drop_front());
903 
904     if (!mlir::MemoryEffectOpInterface::hasNoEffect(op)) {
905       // Register the converted op so that it can be retrieved by successors.
906       // TODO(chky): Add OpTraits or OpInterface, rather than assume first
907       // result is a chain.
908       corert_converter_.RegisterLocalSideEffectChain(op, new_op.getResult(0));
909     }
910 
911     return success();
912   }
913 
914  private:
915   mlir::TypeConverter &type_converter_;
916   CoreRTConverter &corert_converter_;
917   bool func_use_fallback_tensor_;
918 };
919 
920 // Convert standard ReturnOp to tfrt.return.
921 //
922 // TODO(chky): conversion to tfrt kernels should come from a common tf_to_tfrt
923 // library.
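//
// Illustrative sketch (operand types depend on whether fallback tensors or
// corert tensor handles are used; textual form approximate, not verbatim IR):
//
//   return %x : tensor<i32>
//
// becomes roughly
//
//   tfrt.return %out_chain, %x : !tfrt.chain, !corert.tensorhandle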
924 class TFRTReturnOpConversion
925     : public mlir::OpConversionPattern<mlir::ReturnOp> {
926  public:
927   TFRTReturnOpConversion(mlir::MLIRContext *context,
928                          CoreRTConverter *corert_converter,
929                          bool func_use_fallback_tensor)
930       : mlir::OpConversionPattern<mlir::ReturnOp>(context),
931         corert_converter_(*corert_converter),
932         func_use_fallback_tensor_(func_use_fallback_tensor) {}
933 
934   LogicalResult matchAndRewrite(
935       mlir::ReturnOp op, ArrayRef<mlir::Value> operands,
936       ConversionPatternRewriter &rewriter) const override {
937     llvm::SmallVector<mlir::Value, 2> new_operands;
938 
939     // Currently in mlir::TF::SideEffectAnalysis, all terminator ops are treated
940     // as side-effect ops and they have predecessors but not successors.
941     //
942     // TODO(chky): ReturnOp has no side effect. Once the special handling in
943     // mlir::TF::SideEffectAnalysis is removed, the chains should come from
944     // side-effecting ops with no successors in the function.
945     new_operands.push_back(
946         corert_converter_.GetLocalSideEffectChain(op, &rewriter));
947     if (mlir::failed(ConvertFunctionCallOperands(
948             op, operands, &new_operands, rewriter, func_use_fallback_tensor_)))
949       return failure();
950 
951     rewriter.replaceOpWithNewOp<tfrt::compiler::ReturnOp>(op, new_operands);
952     return success();
953   }
954 
955  private:
956   CoreRTConverter &corert_converter_;
957   bool func_use_fallback_tensor_;
958 };
959 
960 // Convert tf.Case op to tfrt.case.
961 //
962 // TF dialect:
963 // %outputs = "tf.Case"(%arg, ...) { branches = [@branch0, @branch1], ...}
964 //
965 // lowered TFRT CoreRT dialect:
966 // %idx_int = corert.tensorhandle_to_int32 %idx
967 // %out_chain, %outputs = tfrt.case %idx_int [@branch0, @branch1] (%chain, ...)
968 class TFRTCaseOpConversion : public mlir::OpConversionPattern<TF::CaseOp> {
969  public:
970   TFRTCaseOpConversion(mlir::MLIRContext *context,
971                        mlir::TypeConverter *type_converter,
972                        CoreRTConverter *corert_converter,
973                        bool func_use_fallback_tensor)
974       : mlir::OpConversionPattern<TF::CaseOp>(context),
975         type_converter_(*type_converter),
976         corert_converter_(*corert_converter),
977         func_use_fallback_tensor_(func_use_fallback_tensor) {}
978 
979   LogicalResult matchAndRewrite(
980       TF::CaseOp op, ArrayRef<mlir::Value> operands,
981       ConversionPatternRewriter &rewriter) const override {
982     mlir::ArrayAttr branches = op.branches();
983 
984     llvm::SmallVector<mlir::Type, 4> result_types;
985     for (mlir::Type type : op->getResultTypes()) {
986       if (failed(type_converter_.convertType(type, result_types)))
987         return failure();
988     }
989 
990     llvm::SmallVector<mlir::Value, 4> branch_operands;
991     if (mlir::failed(ConvertFunctionCallOperands(op, operands.drop_front(),
992                                                  &branch_operands, rewriter,
993                                                  func_use_fallback_tensor_)))
994       return failure();
995 
996     mlir::Value index_operand = operands[0];
997     // TODO(b/182233401): Support TF tensor; remove the conversion op here.
998     if (index_operand.getType().isa<tfrt::fallback::TFTensorType>()) {
999       // TODO(b/182232457): Support other devices.
1000       index_operand =
1001           rewriter
1002               .create<
1003                   tfrt::fallback_async::FallbackTensorToCoreRTTensorHandleOp>(
1004                   op.getLoc(),
1005                   rewriter.getType<tfrt::corert::TensorHandleType>(),
1006                   operands[0], kCpuDeviceName)
1007               .getResult(0);
1008     }
1009     if (!index_operand.getType().isa<tfrt::corert::TensorHandleType>())
1010       return op.emitError(
1011           "branch index operand is expected to be a TensorHandle.");
1012     mlir::Value index_value =
1013         rewriter.create<tfrt::corert::TensorHandleToIntOp>(
1014             op.getLoc(), rewriter.getI32Type(), index_operand);
1015 
1016     auto new_op = rewriter.create<tfrt::compiler::CaseOp>(
1017         op.getLoc(), corert_converter_.chain_type(), result_types, index_value,
1018         branches, corert_converter_.GetLocalSideEffectChain(op, &rewriter),
1019         branch_operands);
1020 
1021     rewriter.replaceOp(op, new_op.branch_outputs());
1022     return success();
1023   }
1024 
1025  private:
1026   mlir::TypeConverter &type_converter_;
1027   CoreRTConverter &corert_converter_;
1028   bool func_use_fallback_tensor_;
1029 };
1030 
1031 static mlir::Value GetPredicate(mlir::Operation *op, mlir::Value cond_operand,
1032                                 mlir::ConversionPatternRewriter &rewriter) {
1033   if (!cond_operand.getType().isa<tfrt::fallback::TFTensorType>()) {
1034     cond_operand = tfrt_compiler::ConvertCoreRTTensorHandleToFallbackTensor(
1035         op->getLoc(), kCpuDeviceName, cond_operand, rewriter);
1036     if (!cond_operand) {
1037       op->emitWarning("failed to convert the cond operand to fallback tensor.");
1038       return {};
1039     }
1040   }
1041 
1042   return rewriter.create<tfrt::fallback_async::PredicateOp>(
1043       op->getLoc(), rewriter.getI1Type(), cond_operand);
1044 }
1045 
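// Convert tf.If to tfrt.cond. Illustrative sketch (textual form approximate,
// not verbatim IR):
//
//   %0 = "tf.If"(%cond, %arg) {then_branch = @then_fn, else_branch = @else_fn}
//
// becomes roughly
//
//   %cond_bool = tfrt_fallback_async.predicate %cond
//   %out_chain, %0 = tfrt.cond %cond_bool @then_fn @else_fn (%in_chain, %arg)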
1046 class TFRTCondOpConversion : public mlir::OpConversionPattern<mlir::TF::IfOp> {
1047  public:
1048   TFRTCondOpConversion(mlir::MLIRContext *context,
1049                        mlir::TypeConverter *type_converter,
1050                        CoreRTConverter *corert_converter,
1051                        bool func_use_fallback_tensor)
1052       : mlir::OpConversionPattern<TF::IfOp>(context),
1053         type_converter_(*type_converter),
1054         corert_converter_(*corert_converter),
1055         func_use_fallback_tensor_(func_use_fallback_tensor) {}
1056 
1057   mlir::LogicalResult matchAndRewrite(
1058       mlir::TF::IfOp op, llvm::ArrayRef<mlir::Value> operands,
1059       mlir::ConversionPatternRewriter &rewriter) const override {
1060     mlir::FlatSymbolRefAttr then_branch = op.then_branchAttr();
1061     mlir::FlatSymbolRefAttr else_branch = op.else_branchAttr();
1062 
1063     llvm::SmallVector<Type, 4> result_types;
1064     result_types.push_back(rewriter.getType<tfrt::compiler::ChainType>());
1065     for (mlir::Type type : op.getOperation()->getResultTypes()) {
1066       if (failed(type_converter_.convertType(type, result_types)))
1067         return failure();
1068     }
1069 
1070     // Convert the cond tensor to a boolean value so that it can be used by
1071     // tfrt.cond kernel.
1072     auto bool_cond = GetPredicate(op, operands[0], rewriter);
1073     if (!bool_cond) return failure();
1074 
1075     llvm::SmallVector<mlir::Value, 4> new_operands;
1076     // The first arg of converted branch functions should be !tfrt.chain.
1077     new_operands.push_back(
1078         corert_converter_.GetLocalSideEffectChain(op, &rewriter));
1079 
1080     if (mlir::failed(ConvertFunctionCallOperands(op, operands.drop_front(),
1081                                                  &new_operands, rewriter,
1082                                                  func_use_fallback_tensor_)))
1083       return failure();
1084 
1085     auto new_op = rewriter.create<tfrt::compiler::CondOp>(
1086         op.getLoc(), result_types, bool_cond, then_branch, else_branch,
1087         new_operands);
1088 
1089     // The first result is a !tfrt.chain.
1090     rewriter.replaceOp(op, new_op.getResults().drop_front(1));
1091 
1092     if (!mlir::MemoryEffectOpInterface::hasNoEffect(op)) {
1093       // Register the converted op so that it can be retrieved by successors.
1094       // TODO(chky): Add OpTraits or OpInterface, rather than assume first
1095       // result is a chain.
1096       corert_converter_.RegisterLocalSideEffectChain(op, new_op.getResult(0));
1097     }
1098 
1099     return success();
1100   }
1101 
1102  private:
1103   mlir::TypeConverter &type_converter_;
1104   CoreRTConverter &corert_converter_;
1105   bool func_use_fallback_tensor_;
1106 };
1107 
1108 // Convert TF WhileOp to tfrt.while. tfrt.while uses a boolean condition and
1109 // has slightly different semantics from tf.While for performance and generality.
1110 // The pseudo code of tfrt.while is as follows:
1111 //
1112 //  while(cond) {
1113 //    outputs, cond = body(inputs)
1114 //    inputs = outputs
1115 //  }
1116 //  return outputs
1117 //
1118 // So we need to insert extra conversion kernels and merge functions when
1119 // lowering tf.While to tfrt.while.
1120 //
1121 //  %result = tf.While(%arg) {cond = @original_cond_fn, body =
1122 //  @original_body_fn}
1123 //
1124 // is converted to
1125 //
1126 //  func @new_pred_fn(%ch, %arg) {
1127 //    %ch0, %cond_tensor = tfrt.call @original_cond_fn(%ch, %arg)
1128 //    %cond_bool = tfrt_fallback_async.predicate %cond_tensor
1129 //    tfrt.return %ch0, %cond_bool
1130 //  }
1131 //
1132 //  func @new_while_body(%ch, %arg) {
1133 //    %ch0, %result = tfrt.call @original_body_fn(%ch, %arg)
1134 //    %ch1, %cond_bool = tfrt.call @new_pred_fn(%ch0, %result)
1135 //    tfrt.return %ch1, %result, %cond_bool
1136 //  }
1137 //
1138 //  %ch0, %first_iter_cond = tfrt.call @new_pred_fn(%ch, %arg)
1139 //  %ch1, %result = tfrt.while %first_iter_cond @new_while_body(%ch0, %arg)
1140 //
1141 class TFRTWhileOpConversion
1142     : public mlir::OpConversionPattern<mlir::TF::WhileOp> {
1143  public:
1144   TFRTWhileOpConversion(mlir::MLIRContext *context,
1145                         mlir::TypeConverter *type_converter,
1146                         CoreRTConverter *corert_converter,
1147                         mlir::SymbolTable *symbol_table,
1148                         const tfrt_compiler::TensorArraySideEffectAnalysis
1149                             *tensor_array_side_effect_analysis,
1150                         bool func_use_fallback_tensor)
1151       : mlir::OpConversionPattern<TF::WhileOp>(context),
1152         type_converter_(*type_converter),
1153         corert_converter_(*corert_converter),
1154         symbol_table_(*symbol_table),
1155         tensor_array_side_effect_analysis_(*tensor_array_side_effect_analysis),
1156         func_use_fallback_tensor_(func_use_fallback_tensor) {}
1157 
1158   mlir::LogicalResult matchAndRewrite(
1159       mlir::TF::WhileOp op, llvm::ArrayRef<mlir::Value> operands,
1160       mlir::ConversionPatternRewriter &rewriter) const override {
1161     mlir::FlatSymbolRefAttr cond_fn = op.condAttr();
1162     mlir::FlatSymbolRefAttr body_fn = op.bodyAttr();
1163     // TODO(hanbinyoon): Support the parallel_iterations attribute.
1164 
1165     llvm::SmallVector<Type, 4> while_arg_result_types;
1166     // Insert a chain for side effects as the first argument/result.
1167     while_arg_result_types.push_back(
1168         rewriter.getType<tfrt::compiler::ChainType>());
1169     for (mlir::Type type : op.getOperation()->getResultTypes()) {
1170       if (failed(type_converter_.convertType(type, while_arg_result_types)))
1171         return failure();
1172     }
1173 
1174     // Convert the operands to either fallback tensors or corert tensors, as
1175     // specified in the option.
1176     llvm::SmallVector<mlir::Value, 4> new_operands;
1177     if (mlir::failed(ConvertFunctionCallOperands(
1178             op, operands, &new_operands, rewriter, func_use_fallback_tensor_)))
1179       return failure();
1180 
1181     // Create the predicate function that calls the original cond function and,
1182     // in addition, converts the result to a boolean value.
1183     mlir::FuncOp pred_fn =
1184         GetPredicateFunction(op, cond_fn, while_arg_result_types, rewriter);
1185     if (!pred_fn) return failure();
1186 
1187     auto in_chain = corert_converter_.GetLocalSideEffectChain(op, &rewriter);
1188     auto out_chain = in_chain;
1189 
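         // When the cond and body functions have at most TensorArray side effects,
         // the loop does not need to be sequenced with other side-effecting ops, so
         // the function's input chain is used below instead of the local
         // side-effect chain.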
1190     bool has_at_most_tensor_array_effect = HasAtMostTensorArrayEffect(op);
1191 
1192     // Prepare the arguments to call the pred function for the first iteration.
1193     llvm::SmallVector<mlir::Value, 4> pred_args;
1194     pred_args.push_back(
1195         has_at_most_tensor_array_effect ? GetFunctionInputChain(op) : in_chain);
1196     pred_args.append(new_operands.begin(), new_operands.end());
1197 
1198     // Insert a call op to call the pred function for the first iteration.
1199     auto call_pred_fn = rewriter.create<tfrt::compiler::CallOp>(
1200         op.getLoc(), pred_fn.getType().getResults(), pred_fn.sym_name(),
1201         pred_args);
1202 
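         // The pred function returns a !tfrt.chain followed by the boolean
         // condition; see GetPredicateFunction below.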
1203     auto pred_chain = call_pred_fn.getResult(0);
1204     auto first_iteration_bool_cond = call_pred_fn.getResult(1);
1205 
1206     mlir::FuncOp new_body_fn = GetWhileBodyFunction(
1207         op, body_fn, pred_fn, while_arg_result_types, rewriter);
1208 
1209     // Use the pred chain as the chain to the while body. The rest of the args
1210     // should be the same as pred_args.
1211     auto &while_args = pred_args;
1212     while_args[0] = pred_chain;
1213 
1214     auto new_op = rewriter.create<tfrt::compiler::WhileOp>(
1215         op.getLoc(), while_arg_result_types, first_iteration_bool_cond,
1216         while_args, new_body_fn.sym_name());
1217 
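         // tfrt.while returns a chain followed by the loop results, so drop the
         // leading chain when replacing the results of the original tf.While.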
1218     rewriter.replaceOp(op, new_op.getResults().drop_front());
1219 
1220     if (!has_at_most_tensor_array_effect) out_chain = new_op.getResult(0);
1221 
1222     if (!mlir::MemoryEffectOpInterface::hasNoEffect(op)) {
1223       // Register the converted op so that it can be retrieved by successors.
1224       // TODO(chky): Add OpTraits or OpInterface, rather than assume first
1225       // result is a chain.
1226       corert_converter_.RegisterLocalSideEffectChain(op, out_chain);
1227     }
1228     return success();
1229   }
1230 
1231  private:
1232   bool HasAtMostTensorArrayEffect(mlir::TF::WhileOp op) const {
1233     return tensor_array_side_effect_analysis_.HasAtMostTensorArrayEffect(
1234                op.cond_function()) &&
1235            tensor_array_side_effect_analysis_.HasAtMostTensorArrayEffect(
1236                op.body_function());
1237   }
1238 
1239   mlir::FuncOp GetPredicateFunction(
1240       mlir::TF::WhileOp op, mlir::FlatSymbolRefAttr cond_fn,
1241       mlir::TypeRange arg_types,
1242       mlir::ConversionPatternRewriter &rewriter) const;
1243 
1244   mlir::FuncOp GetWhileBodyFunction(
1245       mlir::TF::WhileOp op, mlir::FlatSymbolRefAttr body_fn,
1246       mlir::FuncOp pred_fn, mlir::TypeRange arg_types,
1247       mlir::ConversionPatternRewriter &rewriter) const;
1248 
1249   mlir::TypeConverter &type_converter_;
1250   CoreRTConverter &corert_converter_;
1251   mlir::SymbolTable &symbol_table_;
1252   const tfrt_compiler::TensorArraySideEffectAnalysis
1253       &tensor_array_side_effect_analysis_;
1254   bool func_use_fallback_tensor_;
1255 };
1256 
1257 // Create the pred function that contains a call to the original cond function
1258 // and a predicate kernel that converts the cond tensor to a boolean value. eg.
1259 //
1260 // func @pred_fn(%ch, %arg) {
1261 //  %ch0, %cond_tensor = tfrt.call @original_cond_fn(%ch, %arg)
1262 //  %cond_bool = tfrt_fallback_async.predicate %cond_tensor
1263 //  return %ch0, %cond_bool
1264 // }
1265 //
1266 mlir::FuncOp TFRTWhileOpConversion::GetPredicateFunction(
1267     mlir::TF::WhileOp op, mlir::FlatSymbolRefAttr cond_fn,
1268     mlir::TypeRange arg_types,
1269     mlir::ConversionPatternRewriter &rewriter) const {
1270   std::string pred_fn_name = cond_fn.getValue().str() + "/tfrt_predicate";
1271 
1272   if (auto pred_fn = symbol_table_.lookup<mlir::FuncOp>(pred_fn_name)) {
1273     return pred_fn;
1274   }
1275 
1276   auto func_op = op->getParentOfType<mlir::FuncOp>();
1277 
1278   mlir::ConversionPatternRewriter::InsertionGuard insertion_guard(rewriter);
1279   rewriter.setInsertionPointAfter(func_op);
1280 
1281   std::array<mlir::Type, 2> pred_result_types = {
1282       rewriter.getType<tfrt::compiler::ChainType>(), rewriter.getI1Type()};
1283 
1284   auto func_type = rewriter.getFunctionType(arg_types, pred_result_types);
1285 
1286   auto pred_fn =
1287       rewriter.create<mlir::FuncOp>(op.getLoc(), pred_fn_name, func_type);
1288 
1289   auto *block = pred_fn.addEntryBlock();
1290   rewriter.setInsertionPointToStart(block);
1291 
1292   // There are at least two arguments, with the first being !tfrt.chain and the
1293   // second being either !tfrt_fallback.tf_tensor or !corert.tensorhandle.
1294   // cond_fn must have two results. The first of them must also be !tfrt.chain
1295   // and the second must have the same tensor type as the arguments. So we can
1296   // just use the first two arg types as the result types.
1297   assert(arg_types.size() >= 2);
1298   auto cond_result_types = arg_types.take_front(2);
1299 
1300   auto call_cond_fn = rewriter.create<tfrt::compiler::CallOp>(
1301       op.getLoc(), cond_result_types, cond_fn, block->getArguments());
1302 
1303   auto chain = call_cond_fn.getResult(0);
1304   auto cond = call_cond_fn.getResult(1);
1305 
1306   auto bool_cond = GetPredicate(op, cond, rewriter);
1307   if (!bool_cond) return {};
1308 
1309   llvm::SmallVector<mlir::Value, 2> results = {chain, bool_cond};
1310 
1311   rewriter.create<tfrt::compiler::ReturnOp>(op.getLoc(), results);
1312 
1313   symbol_table_.insert(pred_fn);
1314 
1315   return pred_fn;
1316 }
1317 
1318 // Create the new while body function that contains a call to original while
1319 // body and then a call to the pred function. eg.
1320 //
1321 // func @new_while_body(%ch, %arg) {
1322 //   %ch0, %result = tfrt.call @original_body(%ch, %arg)
1323 //   %ch1, %cond_bool = tfrt.call @pred_function(%ch0, %arg)
1324 //   tfrt.return %ch1, %result, %cond_bool
1325 // }
1326 //
1327 mlir::FuncOp TFRTWhileOpConversion::GetWhileBodyFunction(
1328     mlir::TF::WhileOp op, mlir::FlatSymbolRefAttr original_body_fn,
1329     mlir::FuncOp pred_fn, mlir::TypeRange arg_types,
1330     mlir::ConversionPatternRewriter &rewriter) const {
1331   std::string body_fn_name = original_body_fn.getValue().str() + "/tfrt_body";
1332 
1333   if (auto body_fn = symbol_table_.lookup<mlir::FuncOp>(body_fn_name)) {
1334     return body_fn;
1335   }
1336 
1337   auto func_op = op->getParentOfType<mlir::FuncOp>();
1338 
1339   mlir::ConversionPatternRewriter::InsertionGuard insertion_guard(rewriter);
1340   rewriter.setInsertionPointAfter(func_op);
1341 
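       // The original body function (after conversion) takes and returns a chain
       // followed by the loop-carried tensors, so its result types match arg_types.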
1342   auto while_result_types = arg_types;
1343 
1344   llvm::SmallVector<mlir::Type, 4> body_result_types(arg_types.begin(),
1345                                                      arg_types.end());
1346 
1347   // The last result of the while body function is the boolean condition.
1348   body_result_types.push_back(rewriter.getI1Type());
1349   auto func_type = rewriter.getFunctionType(arg_types, body_result_types);
1350   auto body_fn =
1351       rewriter.create<mlir::FuncOp>(op.getLoc(), body_fn_name, func_type);
1352 
1353   auto *block = body_fn.addEntryBlock();
1354   rewriter.setInsertionPointToStart(block);
1355 
1356   // Insert a call to the original body function.
1357   auto call_original_body_fn = rewriter.create<tfrt::compiler::CallOp>(
1358       op.getLoc(), while_result_types, original_body_fn, block->getArguments());
1359 
1360   // Insert a call to the pred function, which contains a call to the original
1361   // cond function and the predicate kernel that converts the tensor to boolean
1362   // value.
1363   auto call_pred_fn = rewriter.create<tfrt::compiler::CallOp>(
1364       op.getLoc(), pred_fn.getType().getResults(), pred_fn.sym_name(),
1365       call_original_body_fn.getResults());
1366 
1367   auto pred_chain = call_pred_fn.getResult(0);
1368 
1369   llvm::SmallVector<mlir::Value, 4> body_results;
1370   // The first result should be the chain returned from the pred function.
1371   body_results.push_back(pred_chain);
1372   // Then come the results from the original while body, excluding the first
1373   // chain (which is replaced by the `pred_chain`).
1374   for (auto res : llvm::drop_begin(call_original_body_fn.getResults())) {
1375     body_results.push_back(res);
1376   }
1377   // The last result should be the boolean value converted from the condition.
1378   auto bool_cond = call_pred_fn.getResult(1);
1379   body_results.push_back(bool_cond);
1380 
1381   rewriter.create<tfrt::compiler::ReturnOp>(op.getLoc(), body_results);
1382 
1383   symbol_table_.insert(body_fn);
1384 
1385   return body_fn;
1386 }
1387 
1388 // TODO(ezhulenev): tf_device.cluster operations after auto-fusion should
1389 // have the correct device assigned based on the fused operations. We should
1390 // use this device to convert operands and results from/to corert handles.
1391 // For now it is safe to assume that it is "CPU" because we do not support
1392 // any other devices and do not support distributed models.
1393 constexpr char kCpuRtDevice[] = "CPU";
1394 
1395 // Convert cpurt.call operations to the tf_cpurt.fallback.execute operation.
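     // A rough sketch of the rewrite (the textual MLIR syntax below is approximate
     // and for illustration only):
     //
     //   %results = cpurt.call(%operands) {callee = @kernel::@compute}
     //
     // becomes, roughly,
     //
     //   %results = tf_cpurt.fallback.execute @kernel::@compute (%fallback_operands)
     //              {device = "CPU"}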
1396 class CpuRtCallToCpurtCompileAndExecuteConversion
1397     : public OpConversionPattern<tfrt::cpu::jit::CallOp> {
1398  public:
1399   explicit CpuRtCallToCpurtCompileAndExecuteConversion(MLIRContext *context)
1400       : OpConversionPattern<tfrt::cpu::jit::CallOp>(context) {}
1401 
1402   LogicalResult matchAndRewrite(
1403       tfrt::cpu::jit::CallOp call, ArrayRef<Value> operands,
1404       ConversionPatternRewriter &rewriter) const override {
1405     // Convert operands to fallback tensors.
1406     llvm::SmallVector<Value, 4> fallback_operands;
1407     if (failed(tfrt_compiler::ConvertFallbackOperands(
1408             call, kCpuRtDevice, operands, &fallback_operands, rewriter)))
1409       return rewriter.notifyMatchFailure(call, "failed to convert operand");
1410 
1411     // tf_cpurt.fallback.execute always produces fallback tensors.
1412     llvm::SmallVector<Type, 4> result_types(
1413         call->getNumResults(),
1414         rewriter.getType<tfrt::fallback::TFTensorType>());
1415 
1416     // Replace cpurt.call operation with a tf_cpurt.fallback.execute operation.
1417     rewriter.replaceOpWithNewOp<tf_cpurt::FallbackExecuteOp>(
1418         call, result_types, call.callee(), fallback_operands, kCpuRtDevice);
1419 
1420     return success();
1421   }
1422 };
1423 
1424 // Helper function for specifying legal dialects for conversion to CoreRT.
1425 void SetUpTFToTFRTConversionLegality(mlir::ConversionTarget *target,
1426                                      mlir::TypeConverter *func_type_converter,
1427                                      mlir::Type chain_type) {
1428   target->addLegalDialect<tfrt::corert::CoreRTDialect>();
1429   target->addLegalDialect<tfrt::fallback_async::FallbackAsyncDialect>();
1430   target->addLegalDialect<tfrt::compiler::TFRTDialect>();
1431   target->addLegalDialect<tfrt::dist::DistributedDialect>();
1432   target->addLegalDialect<tfrt::test::TestDialect>();
1433   target->addLegalDialect<tf_cpurt::CpuRuntimeDialect>();
1434   target->addIllegalDialect<TF::TensorFlowDialect>();
1435   target->addIllegalDialect<tf_device::TensorFlowDeviceDialect>();
1436   target->addIllegalDialect<tfrt::cpu::jit::CpuRuntimeDialect>();
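       // A FuncOp is legal only if its first argument and first result are the
       // chain type and the rest of its signature and body have been converted by
       // the function type converter.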
1437   target->addDynamicallyLegalOp<mlir::FuncOp>([func_type_converter,
1438                                                chain_type](FuncOp op) {
1439     auto func_type = op.getType();
1440     if (func_type.getNumInputs() == 0 || func_type.getInput(0) != chain_type)
1441       return false;
1442     if (func_type.getNumResults() == 0 || func_type.getResult(0) != chain_type)
1443       return false;
1444 
1445     return func_type_converter->isSignatureLegal(op.getType()) &&
1446            func_type_converter->isLegal(&op.getBody());
1447   });
1448 }
1449 
1450 // Helper function for inserting TFRT CpuRT dialect conversions.
1451 void PopulateCpuRtConversionPatterns(MLIRContext *context,
1452                                      OwningRewritePatternList *patterns,
1453                                      CoreRTConverter *corert_converter) {
1454   // Lower cpurt.call to the pair of compile and execute operations.
1455   patterns->insert<CpuRtCallToCpurtCompileAndExecuteConversion>(context);
1456 }
1457 
1458 // Helper function for inserting TF dialect to TFRT dialect op conversion
1459 // patterns.
1460 void PopulateTFToTFRTConversionPatterns(
1461     mlir::MLIRContext *context, mlir::OwningRewritePatternList *patterns,
1462     CoreRTConverter *corert_converter,
1463     tfrt_compiler::FallbackConverter *fallback_converter,
1464     mlir::SymbolTable *symbol_table,
1465     const tfrt_compiler::CostAnalysis *cost_analysis,
1466     const tfrt_compiler::TensorArraySideEffectAnalysis
1467         *tensor_array_side_effect_analysis,
1468     bool enable_native_ops, bool func_use_fallback_tensor,
1469     bool tpu_lower_to_fallback) {
1470   // By default, we lower all TF ops to fallback ops.
1471   patterns->insert<FallbackExecuteOpConversion>(
1472       context, corert_converter, fallback_converter, cost_analysis,
1473       tpu_lower_to_fallback);
1474   patterns->insert<FallbackConstOpConversion, FallbackSetResourceOp,
1475                    FallbackGetResourceOp>(context, corert_converter);
1476 
1477   // For control flow ops, we handle them according to the option.
1478   mlir::TypeConverter *func_type_converter;
1479   if (func_use_fallback_tensor) {
1480     func_type_converter = fallback_converter;
1481   } else {
1482     func_type_converter = corert_converter;
1483   }
1484   patterns->insert<TFRTFuncOpSignatureConversion>(context, func_type_converter);
1485   patterns->insert<TFRTReturnOpConversion>(context, corert_converter,
1486                                            func_use_fallback_tensor);
1487   patterns->insert<TFRTWhileOpConversion>(
1488       context, func_type_converter, corert_converter, symbol_table,
1489       tensor_array_side_effect_analysis, func_use_fallback_tensor);
1490   patterns->insert<TFRTCallOpConversion<mlir::TF::StatefulPartitionedCallOp>,
1491                    TFRTCallOpConversion<mlir::TF::PartitionedCallOp>,
1492                    TFRTCaseOpConversion, TFRTCondOpConversion>(
1493       context, func_type_converter, corert_converter, func_use_fallback_tensor);
1494 
1495   // For tf.BatchFunction, we need a special fallback op to batch a BEF
1496   // function.
1497   patterns->insert<FallbackBatchFunctionOpConversion>(context,
1498                                                       corert_converter);
1499 
1500   // The patterns below are preferred over fallback lowering as we want to use
1501   // the CoreRT interface for native kernels. This is only temporary and will be
1502   // refactored to use the SSOT interface.
1503 
1504   // Here we use specialized patterns for tf.Const on CPU as it is incorrect to
1505   // use the ExecuteOp pattern to convert a string tensor attribute.
1506   patterns->insert<CoreRTConstStringTensorOpConversion,
1507                    CoreRTConstDenseTensorOpConversion>(context,
1508                                                        corert_converter);
1509 
1510   // For tf.Const op on GPU, we still lower it to corert.executeop temporarily.
1511   //
1512   // TODO(chky): Use specialized patterns for tf.Const ops on GPU.
1513   patterns->insert<CoreRTExecuteOpConversion<TF::ConstOp>>(
1514       context, corert_converter, kGpuDeviceName);
1515 
1516   if (enable_native_ops) {
1517     // The TF operations below will be converted to use corert.executeop, which
1518     // will invoke the TFRT native op if one is implemented.
1519     // TODO(b/187942369): Pattern registration for TF operations is not
1520     // sustainable currently. We need to figure out a plan.
1521     patterns->insert<CoreRTExecuteOpConversion<TF::AddV2Op>,
1522                      // TODO(chky): Move the ReadVariableOp + Identity pattern
1523                      // to optimizer.
1524                      // CoreRTExecuteOpConversion<TF::IdentityOp>,
1525                      CoreRTExecuteOpConversion<TF::MulOp>,
1526                      CoreRTExecuteOpConversion<TF::BiasAddOp>,
1527                      CoreRTExecuteOpConversion<TF::Conv2DOp>,
1528                      CoreRTExecuteOpConversion<TF::ConcatV2Op>,
1529                      CoreRTExecuteOpConversion<TF::CastOp>,
1530                      CoreRTExecuteOpConversion<TF::ExpandDimsOp>,
1531                      CoreRTExecuteOpConversion<TF::TransposeOp>,
1532                      CoreRTExecuteOpConversion<TF::FusedBatchNormV3Op>,
1533                      CoreRTExecuteOpConversion<TF::_FusedBatchNormExOp>,
1534                      CoreRTExecuteOpConversion<TF::LogOp>,
1535                      CoreRTExecuteOpConversion<TF::Log1pOp>,
1536                      CoreRTExecuteOpConversion<TF::LogSoftmaxOp>,
1537                      CoreRTExecuteOpConversion<TF::MatMulOp>,
1538                      CoreRTExecuteOpConversion<TF::_FusedMatMulOp>,
1539                      CoreRTExecuteOpConversion<TF::MaxPoolOp>,
1540                      CoreRTExecuteOpConversion<TF::MeanOp>,
1541                      CoreRTExecuteOpConversion<TF::MulOp>,
1542                      CoreRTExecuteOpConversion<TF::PadOp>,
1543                      CoreRTExecuteOpConversion<TF::RealDivOp>,
1544                      CoreRTExecuteOpConversion<TF::ReluOp>,
1545                      CoreRTExecuteOpConversion<TF::ReshapeOp>,
1546                      CoreRTExecuteOpConversion<TF::RsqrtOp>,
1547                      CoreRTExecuteOpConversion<TF::SoftmaxOp>,
1548                      CoreRTExecuteOpConversion<TF::ShapeOp>,
1549                      CoreRTExecuteOpConversion<TF::SigmoidOp>,
1550                      CoreRTExecuteOpConversion<TF::SubOp>,
1551                      CoreRTExecuteOpConversion<TF::TileOp>,
1552                      CoreRTExecuteOpConversion<TF::TanhOp>,
1553                      CoreRTExecuteOpConversion<TF::ZerosLikeOp>>(
1554         context, corert_converter);
1555   }
1556 }
1557 
1558 // Lower TF dialect MLIR to TFRT dialect.
1559 class TfToTfrtConversionPass
1560     : public mlir::PassWrapper<TfToTfrtConversionPass,
1561                                mlir::OperationPass<mlir::ModuleOp>> {
1562   void getDependentDialects(mlir::DialectRegistry &registry) const override {
1563     getDependentConversionDialects(registry);
1564 
1565     if (target_tpu_) RegisterTPUDialects(&registry);
1566   }
1567 
1568   llvm::StringRef getArgument() const final { return "tf-to-tfrt"; }
1569   llvm::StringRef getDescription() const final {
1570     return "Convert Tensorflow dialect (generated from tf.function) to TFRT "
1571            "dialect.";
1572   }
1573 
1574  public:
1575   TfToTfrtConversionPass() = default;
1576   explicit TfToTfrtConversionPass(const TfrtPipelineOptions &options) {
1577     target_tpu_ = options.target_tpu;
1578     enable_native_ops_ = options.enable_native_ops;
1579     tpu_use_core_selector_ = options.tpu_use_core_selector;
1580     tpu_use_bundled_transfer_ = options.tpu_use_bundled_transfer;
1581     tpu_lower_to_fallback_ = options.tpu_lower_to_fallback;
1582     tpu_transfer_result_to_host_ = options.tpu_transfer_result_to_host;
1583     cost_threshold_ = options.cost_threshold;
1584     upper_cost_threshold_ = options.upper_cost_threshold;
1585     merge_inter_dependent_streams_ = options.merge_inter_dependent_streams;
1586     func_use_fallback_tensor_ = options.func_use_fallback_tensor;
1587   }
1588   TfToTfrtConversionPass(const TfToTfrtConversionPass &) {}
1589 
1590   mlir::LogicalResult runOnFunction(
1591       mlir::FuncOp func,
1592       const mlir::TF::SideEffectAnalysis::Info &side_effect_analysis,
1593       const tfrt_compiler::TensorArraySideEffectAnalysis
1594           &tensor_array_side_effect_analysis,
1595       tfrt_compiler::FallbackConverter &fallback_converter,
1596       mlir::SymbolTable &symbol_table) {
1597     auto &context = getContext();
1598     mlir::ConversionTarget target(context);
1599     mlir::OwningRewritePatternList patterns(&getContext());
1600     CoreRTConverter corert_converter(&context, &side_effect_analysis);
1601     tfrt_compiler::CostAnalysis cost_analysis(func);
1602 
1603     if (target_tpu_)
1604       AddTPUTargetDialectAndPatterns(
1605           &target, &patterns, &context, &corert_converter,
1606           TfrtTpuExecuteOpConversionOptions{tpu_use_core_selector_,
1607                                             tpu_use_bundled_transfer_,
1608                                             tpu_transfer_result_to_host_},
1609           tpu_lower_to_fallback_);
1610 
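         // The function type converter determines whether function boundaries (and
         // other control flow ops) use fallback TF tensors or corert tensorhandles.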
1611     mlir::TypeConverter *func_type_converter;
1612     if (func_use_fallback_tensor_) {
1613       func_type_converter = &fallback_converter;
1614     } else {
1615       func_type_converter = &corert_converter;
1616     }
1617     SetUpTFToTFRTConversionLegality(&target, func_type_converter,
1618                                     corert_converter.chain_type());
1619     PopulateCpuRtConversionPatterns(&context, &patterns, &corert_converter);
1620 
1621     PopulateTFToTFRTConversionPatterns(
1622         &context, &patterns, &corert_converter, &fallback_converter,
1623         &symbol_table, &cost_analysis, &tensor_array_side_effect_analysis,
1624         enable_native_ops_, func_use_fallback_tensor_, tpu_lower_to_fallback_);
1625 
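         // Apply a partial conversion: ops that are not explicitly marked illegal
         // are allowed to remain unconverted.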
1626     return mlir::applyPartialConversion(func, target, std::move(patterns));
1627   }
1628 
1629   void runOnOperation() override {
1630     auto module = getOperation();
1631     const auto &side_effect_analysis =
1632         getAnalysis<mlir::TF::SideEffectAnalysis>();
1633 
1634     tfrt_compiler::TensorArraySideEffectAnalysis
1635         tensor_array_side_effect_analysis(module);
1636     tfrt_compiler::FallbackConverter fallback_converter(&getContext());
1637 
1638     mlir::SymbolTable symbol_table(module);
1639 
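         // Copy the functions into a vector first, because the conversion may
         // insert new functions (e.g. the pred/body functions created for tf.While)
         // into the module while we iterate.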
1640     auto func_op_range = module.getOps<mlir::FuncOp>();
1641     llvm::SmallVector<mlir::FuncOp, 4> func_ops(func_op_range.begin(),
1642                                                 func_op_range.end());
1643     for (auto func : func_ops) {
1644       if (!func.isExternal()) {
1645         if (mlir::failed(runOnFunction(
1646                 func, side_effect_analysis.GetAnalysisForFunc(func),
1647                 tensor_array_side_effect_analysis, fallback_converter,
1648                 symbol_table))) {
1649           signalPassFailure();
1650           return;
1651         }
1652 
1653         ChainDanglingValuesinFunction(func);
1654       }
1655     }
1656 
1657     CreateFallbackInitializationFunction(module, fallback_converter);
1658 
1659     // Set the cost threshold as a module attribute. It will be used later at
1660     // runtime to decide whether a certain execution is cheap enough to be
1661     // executed inline.
1662     mlir::Builder builder(module);
1663     module->setAttr("tfrt.cost_threshold",
1664                     builder.getI64IntegerAttr(cost_threshold_));
1665     module->setAttr("tfrt.upper_cost_threshold",
1666                     builder.getI64IntegerAttr(upper_cost_threshold_));
1667     module->setAttr("tfrt.merge_inter_dependent_streams",
1668                     builder.getBoolAttr(merge_inter_dependent_streams_));
1669   }
1670 
1671  private:
1672   // Chain all dangling values (i.e. values with no users) together and merge
1673   // them with the first returned chain. This merged chain can be used to signal
1674   // the completion of all execution, including side effects.
1675   void ChainDanglingValuesinFunction(mlir::FuncOp func_op) {
1676     auto &block = func_op.front();
1677 
1678     llvm::SmallVector<mlir::Value, 2> dangling_values;
1679 
1680     // Find dangling function arguments.
1681     for (auto arg : block.getArguments()) {
1682       if (arg.use_empty()) {
1683         dangling_values.push_back(arg);
1684       }
1685     }
1686 
1687     // Find dangling values produced by ops in the function.
1688     for (auto &op : block) {
1689       for (mlir::Value result : op.getResults()) {
1690         if (result.use_empty()) {
1691           dangling_values.push_back(result);
1692         }
1693       }
1694     }
1695 
1696     if (dangling_values.empty()) return;
1697 
1698     // Merge these dangling values with the first returned chain.
1699     auto return_op =
1700         llvm::cast<tfrt::compiler::ReturnOp>(block.getTerminator());
1701     auto chain = return_op->getOperand(0);
1702     assert(chain.getType().isa<tfrt::compiler::ChainType>());
1703     dangling_values.push_back(chain);
1704 
1705     mlir::OpBuilder builder(return_op);
1706 
1707     auto new_chain = builder.create<tfrt::compiler::MergeChainsOp>(
1708         return_op->getLoc(), builder.getType<tfrt::compiler::ChainType>(),
1709         dangling_values);
1710 
1711     return_op->setOperand(0, new_chain);
1712   }
1713 
1714   void CreateFallbackInitializationFunction(
1715       mlir::ModuleOp module,
1716       tfrt_compiler::FallbackConverter &fallback_converter) {
1717     mlir::OpBuilder builder(&module.body());
1718 
1719     auto chain_type = builder.getType<tfrt::compiler::ChainType>();
1720 
1721     auto func_op = builder.create<mlir::FuncOp>(
1722         module.getLoc(), "_tfrt_fallback_init",
1723         mlir::FunctionType::get(module.getContext(), /*inputs=*/chain_type,
1724                                 /*results=*/chain_type));
1725 
1726     auto *block = func_op.addEntryBlock();
1727     builder.setInsertionPointToStart(block);
1728 
1729     mlir::Value chain_value = block->getArgument(0);
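         // Thread the chain through each kernel-creation op below so they are
         // sequenced, and so the final chain can be merged with the compilation
         // chains and returned from the init function.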
1730 
1731     // Create operations for all fallback kernels in the module.
1732     for (auto *op : fallback_converter.GetFallbackOps()) {
1733       auto create_op = builder.create<tfrt::fallback_async::CreateOp>(
1734           func_op.getLoc(), chain_type, chain_value);
1735 
1736       create_op->setAttrs(op->getAttrs());
1737       create_op->setAttr("num_args", builder.getI64IntegerAttr(GetNumArgs(op)));
1738 
1739       chain_value = create_op;
1740     }
1741 
1742     // Pre-compile all JIT compiled kernels found in the module.
1743     llvm::SmallVector<Value> compiled;
1744 
1745     // A set of SymbolRef attributes referencing compiled kernels.
1746     llvm::DenseSet<mlir::Attribute> kernels;
1747 
1748     // Compile all kernels in parallel.
1749     module.walk([&](tf_cpurt::FallbackExecuteOp execute) {
1750       // Do not compile the same kernel multiple times.
1751       if (kernels.contains(execute.kernel())) return;
1752 
1753       auto compile = builder.create<tf_cpurt::FallbackCompileOp>(
1754           execute.getLoc(), chain_type, execute.kernel(), execute.device());
1755       compiled.push_back(compile.getResult());
1756       kernels.insert(compile.kernel());
1757     });
1758 
1759     // Wait for compilation to complete before returning from the init function.
1760     if (compiled.size() == 1) {
1761       chain_value = compiled[0];
1762     } else if (compiled.size() > 1) {
1763       chain_value = builder.create<tfrt::compiler::MergeChainsOp>(
1764           func_op.getLoc(), chain_type, compiled);
1765     }
1766 
1767     builder.create<tfrt::compiler::ReturnOp>(func_op.getLoc(), chain_value);
1768   }
1769 
1770   int64_t GetNumArgs(mlir::Operation *fallback_op) {
1771     if (auto execute_op =
1772             llvm::dyn_cast<tfrt::fallback_async::ExecuteOp>(fallback_op)) {
1773       return execute_op.operands().size();
1774     } else if (auto execute_op_seq =
1775                    llvm::dyn_cast<tfrt::fallback_async::ExecuteOpSeq>(
1776                        fallback_op)) {
1777       return execute_op_seq.operands().size();
1778     }
1779     llvm_unreachable("invalid fallback op type");
1780   }
1781 
1782   Option<bool> target_tpu_{*this, "target-tpu",
1783                            llvm::cl::desc("Target TPU programs if true."),
1784                            llvm::cl::init(false)};
1785 
1786   Option<bool> enable_native_ops_{
1787       *this, "enable-native-ops",
1788       llvm::cl::desc(
1789           "If true, native ops will be used on an opt-in basis "
1790           "instead of fallback ops. If false, no native ops are used."),
1791       llvm::cl::init(true)};
1792 
1793   Option<bool> tpu_use_core_selector_{
1794       *this, "tpu-use-core-selector",
1795       llvm::cl::desc("If true, use ServingCoreSelector to pick TPU core. "
1796                      "Otherwise, use the assigned core."),
1797       llvm::cl::init(true)};
1798 
1799   Option<bool> tpu_use_bundled_transfer_{
1800       *this, "tpu-use-bundled-transfer",
1801       llvm::cl::desc("If true, use BundledTransferToTpuOp to transfer "
1802                      "variables and input tensors to TPU."),
1803       llvm::cl::init(true)};
1804 
1805   Option<bool> tpu_lower_to_fallback_{
1806       *this, "tpu-lower-to-fallback",
1807       llvm::cl::desc("If true, lower an TF op that's placed on TPU device "
1808                      "to be executed by tfrt_fallback.execute."),
1809       llvm::cl::init(true)};
1810 
1811   // TODO(b/194081364): remove this option once we unify servo TPU serving
1812   // result transfer behavior.
1813   Option<bool> tpu_transfer_result_to_host_{
1814       *this, "tpu-transfer-result-to-host",
1815       llvm::cl::desc("If true, transfer the result of tpurt.execute from TPU "
1816                      "to host."),
1817       llvm::cl::init(true)};
1818 
1819   Option<uint64_t> cost_threshold_{
1820       *this, "tfrt-cost-threshold",
1821       llvm::cl::desc(
1822           "The cost threshold to decide whether a sequence of operations is "
1823           "cheap, and then whether it can be executed inline."),
1824       llvm::cl::init(1)};
1825 
1826   Option<int64_t> upper_cost_threshold_{
1827       *this, "tfrt-upper-cost-threshold",
1828       llvm::cl::desc(
1829           "The threshold to limit the merging of dependent sequence."),
1830       llvm::cl::init(-1)};
1831 
1832   Option<bool> merge_inter_dependent_streams_{
1833       *this, "tfrt-merge-inter-dependent-streams",
1834       llvm::cl::desc("If true, streams with inter data depenedencies will be "
1835                      "preferred to be merged for inline execution."),
1836       llvm::cl::init(false)};
1837 
1838   Option<bool> func_use_fallback_tensor_{
1839       *this, "func-use-fallback-tensor",
1840       llvm::cl::desc(
1841           "If true, use TF tensor as input/output types in func (and other "
1842           "control flow) ops."),
1843       llvm::cl::init(false)};
1844 };
1845 
1846 }  // namespace
1847 
1848 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
1849 CreateTfToTfrtConversionPass(const TfrtPipelineOptions &options) {
1850   return std::make_unique<TfToTfrtConversionPass>(options);
1851 }
1852 
1853 // -------------------------------------------------------------------------- //
1854 // Outline tf_device.cluster operation regions into functions in the nested
1855 // modules and replaces all cluster operations with cpurt.call operations.
1856 // -------------------------------------------------------------------------- //
1857 
1858 class OutlineCpuRtClustersPass
1859     : public PassWrapper<OutlineCpuRtClustersPass, OperationPass<ModuleOp>> {
1860  public:
1861   llvm::StringRef getArgument() const final {
1862     return "tf-outline-cpurt-cluster";
1863   }
1864   llvm::StringRef getDescription() const final {
1865     return "Outlines `tf_device.cluster` operations into functions and "
1866            "replaces them with `cpurt.call` operations.";
1867   }
1868 
1869   void runOnOperation() override;
1870 
1871   void getDependentDialects(mlir::DialectRegistry &registry) const override {
1872     registry.insert<tfrt::cpu::jit::CpuRuntimeDialect>();
1873   }
1874 
1875  private:
1876   struct CompiledModule {
1877     ModuleOp module;
1878     FuncOp entrypoint;
1879     llvm::SetVector<Value> operands;
1880   };
1881 
1882   // Creates a nested module with a single function that will be compiled into
1883   // the kernel at runtime.
1884   CompiledModule CreateCompiledModule(tf_device::ClusterOp cluster,
1885                                       SymbolTable *symbol_table);
1886 
1887   // Update the compiled module entrypoint signature with inferred operand
1888   // constraints.
1889   LogicalResult SetEntrypointConstraints(CompiledModule &compiled);
1890 
1891   // Outlines cluster operation regions into compiled modules, and replaces
1892   // the cluster operation with a cpurt.call operation.
1893   LogicalResult OutlineClusterOp(tf_device::ClusterOp cluster,
1894                                  SymbolTable *symbol_table);
1895 
1896   // Mapping from the outlined module string representation to the module itself
1897   // and an entrypoint function. Used to deduplicate identical modules during
1898   // the `tf_device.cluster` outlining.
1899   llvm::StringMap<std::pair<ModuleOp, FuncOp>> outlined_;
1900 };
1901 
1902 OutlineCpuRtClustersPass::CompiledModule
1903 OutlineCpuRtClustersPass::CreateCompiledModule(tf_device::ClusterOp cluster,
1904                                                SymbolTable *symbol_table) {
1905   MLIRContext *ctx = cluster->getContext();
1906   Location loc = cluster.getLoc();
1907 
1908   // Create a module that will hold compiled function and async wrappers.
1909   // TODO(ezhulenev): Give better names to module and function.
1910   auto compiled_module = ModuleOp::create(loc, {"kernel"});
1911   compiled_module->setAttr("tfrt.compiled", UnitAttr::get(ctx));
1912 
1913   SymbolTable compiled_module_symbol_table(compiled_module);
1914 
1915   // Find out the cluster arguments and their types.
1916   llvm::SetVector<Value> live_ins;
1917   getUsedValuesDefinedAbove(cluster.body(), cluster.body(), live_ins);
1918 
1919   llvm::SmallVector<Type, 4> operand_types;
1920   operand_types.reserve(live_ins.size());
1921   for (Value v : live_ins) operand_types.emplace_back(v.getType());
1922 
1923   // Create a function in the compiled module.
1924   auto compiled_func_type =
1925       FunctionType::get(ctx, operand_types, cluster->getResultTypes());
1926   auto compiled_func = FuncOp::create(loc, "compute", compiled_func_type);
1927   compiled_module_symbol_table.insert(compiled_func);
1928 
1929   // Replace uses of live-in values within cluster region with block arguments.
1930   Block *compiled_func_block = compiled_func.addEntryBlock();
1931   for (auto p : llvm::zip(live_ins, compiled_func_block->getArguments()))
1932     replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), cluster.body());
1933 
1934   // Move all operations in cluster into compiled_func's entry block.
1935   auto &cluster_body = cluster.GetBody().getOperations();
1936   compiled_func_block->getOperations().splice(
1937       compiled_func_block->end(), cluster_body, cluster_body.begin(),
1938       cluster_body.end());
1939 
1940   // Replace `tf_device.return` terminator with `return` in the function body.
1941   auto device_return =
1942       cast<tf_device::ReturnOp>(compiled_func_block->getTerminator());
1943   OpBuilder builder(device_return.getOperation());
1944   builder.create<ReturnOp>(device_return.getLoc(), device_return.getOperands());
1945   device_return.erase();
1946 
1947   // TODO(ezhulenev): MLIR doesn't define operation equivalence upstream yet,
1948   // replace module printing with a more principled solution when available.
1949   // Operations in the cluster can be in a different order and yet define
1950   // identical TensorFlow programs; with the current approach we will not be
1951   // able to detect such duplicates.
1952 
1953   // Serialize prepared module to string.
1954   std::string serialized;
1955   llvm::raw_string_ostream os(serialized);
1956   compiled_module.print(os);
1957 
1958   // Try to find if identical module was already outlined.
1959   auto it = outlined_.find(serialized);
1960 
1961   // Return identical module that was already outlined earlier.
1962   if (it != outlined_.end()) {
1963     compiled_module.erase();  // erase identical module
1964     return {it->second.first, it->second.second, live_ins};
1965   }
1966 
1967   // Insert compiled module into the symbol table and assign it a unique name.
1968   symbol_table->insert(compiled_module);
1969 
1970   // Cache unique module.
1971   outlined_.insert({std::move(serialized), {compiled_module, compiled_func}});
1972 
1973   return {compiled_module, compiled_func, live_ins};
1974 }
1975 
1976 LogicalResult OutlineCpuRtClustersPass::SetEntrypointConstraints(
1977     CompiledModule &compiled) {
1978   FuncOp func = compiled.entrypoint;
1979 
1980   // Functions outlined from cpurt device clusters must have a single block.
1981   assert(func.body().getBlocks().size() == 1 && "expected single block");
1982 
1983   mlir::TFDevice::ClusteringPolicySet policies;
1984   populateTfCpurtConstraintsPolicies(policies);
1985 
1986   // Infer constraints on the values defined in the entrypoint function
1987   // (including function entry block arguments).
1988   mlir::TFDevice::ValuesConstraintSet constraints;
1989   if (failed(mlir::TFDevice::PropagateValuesConstraints(
1990           func.body(), policies, constraints, /*resolve=*/true)))
1991     return failure();
1992 
1993   // Annotate arguments with inferred constraints.
1994   for (unsigned i = 0; i < func.getNumArguments(); ++i) {
1995     if (auto constraint = constraints.GetConstraint(func.getArgument(i))) {
1996       auto constraint_name = mlir::StringAttr::get(
1997           &getContext(), llvm::formatv("{0}", *constraint).str());
1998       func.setArgAttr(i, "cpurt.constraint", constraint_name);
1999     }
2000   }
2001 
2002   return success();
2003 }
2004 
2005 LogicalResult OutlineCpuRtClustersPass::OutlineClusterOp(
2006     tf_device::ClusterOp cluster, SymbolTable *symbol_table) {
2007   Location loc = cluster->getLoc();
2008   OpBuilder builder(cluster);
2009 
2010   CompiledModule compiled_module = CreateCompiledModule(cluster, symbol_table);
2011   FuncOp compiled_func = compiled_module.entrypoint;
2012 
2013   // Add constraints to the entrypoint arguments.
2014   if (failed(SetEntrypointConstraints(compiled_module))) return failure();
2015 
2016   // Replace device cluster with a cpurt.call operation.
2017   auto module_name = *compiled_module.module.sym_name();
2018   auto func_name = compiled_func.sym_name();
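       // Build a nested symbol reference @<outlined module>::@compute so that
       // cpurt.call resolves the entrypoint inside the nested compiled module.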
2019   auto func_flat_ref = builder.getSymbolRefAttr(func_name);
2020   auto func_ref = builder.getSymbolRefAttr(module_name, {func_flat_ref});
2021 
2022   auto cluster_func_op = builder.create<tfrt::cpu::jit::CallOp>(
2023       loc, cluster.getResultTypes(), func_ref,
2024       compiled_module.operands.getArrayRef());
2025 
2026   cluster.replaceAllUsesWith(cluster_func_op);
2027   cluster.erase();
2028 
2029   return success();
2030 }
2031 
2032 void OutlineCpuRtClustersPass::runOnOperation() {
2033   ModuleOp module = getOperation();
2034   SymbolTable symbol_table(module);
2035 
2036   OpBuilder builder(module.getContext());
2037   auto result = module.walk([&](tf_device::ClusterOp cluster) -> WalkResult {
2038     // Ensure that cluster was formed for TFRT JIT compilation.
2039     auto policy = cluster->getAttr("policy").dyn_cast_or_null<StringAttr>();
2040     if (!policy || policy.getValue() != "tfrt.auto-fusion")
2041       return WalkResult::advance();
2042 
2043     if (failed(OutlineClusterOp(cluster, &symbol_table)))
2044       return WalkResult::interrupt();
2045     return WalkResult::advance();
2046   });
2047 
2048   if (result.wasInterrupted()) {
2049     module->emitError("Failed to outline tf_device.cluster operations");
2050     signalPassFailure();
2051   }
2052 }
2053 
2054 std::unique_ptr<Pass> CreateOutlineCpuRtClustersPass() {
2055   return std::make_unique<OutlineCpuRtClustersPass>();
2056 }
2057 
2058 // -------------------------------------------------------------------------- //
2059 
2060 void CreateTFExecutorToTFPipeline(mlir::OpPassManager &pm,
2061                                   const TfrtPipelineOptions &options) {
2062   // Due to b/191304670, functionalized while ops might not have the
2063   // shape_invariant attribute set correctly, which leads to failure in shape
2064 // inference. As a workaround, we conservatively (e.g., we place fewer
2065 // restrictions on tf.While, which will avoid failures but lead to potentially
2066   // less exact shape inference) set the shape_invariant attribute in all
2067   // tf.While ops before performing shape inference.
2068   //
2069   // Note that this pass might not work well with TF XLA bridge, but this is
2070   // fine as TF XLA bridge is run before this pipeline. For CPU ops, less exact
2071   // shape inference may lead to fewer optimizations but it should be fine as it
2072   // is limited to while ops currently.
2073   //
2074   // TODO(b/191304670): Remove this pass once the shape_invariant attribute is
2075   // set correctly in the upstream.
2076   pm.addNestedPass<mlir::FuncOp>(
2077       tfrt_compiler::CreateSetShapeInvariantInWhileOps());
2078 
2079   // We pass the MLIR module through the TF standard pipeline, which, for
2080   // instance, does shape inference, canonicalization, inlining, etc.
2081   mlir::TF::StandardPipelineOptions tf_options;
2082   tf_options.enable_inliner = true;
2083   mlir::TF::CreateTFStandardPipeline(pm, tf_options);
2084 
2085   pm.addPass(mlir::TF::CreateConstantOpDeviceAssignmentPass());
2086 
2087   // After the standard passes, we have MLIR in the TF dialect, and we now
2088   // convert reference variables to resource variables on a best-effort basis.
2089   pm.addPass(CreateConvertReferenceVariableToResourceVariablePass());
2090 
2091   // Move the tf.Assert op to the end of the function, so that it does not
2092   // impose unnecessary control dependencies on other ops.
2093   pm.addPass(tfrt_compiler::CreateReorderTfAssertPass());
2094 
2095   // Optimize the side effects of control flow ops by examining the ops in
2096   // their callees.
2097   pm.addPass(tfrt_compiler::CreateOptimizeTfControlFlowSideEffectPass());
2098 
2099   // Remove tf.If ops' operands that are produced by tf.Const ops.
2100   pm.addPass(tfrt_compiler::CreateRemoveTfIfConstArgsPass());
2101 
2102   // Merge non-side-effecting tf.If ops if their operands are the same.
2103   pm.addPass(tfrt_compiler::CreateMergeTfIfOpsPass());
2104 
2105   // Deduplicate functions invoked by tf.BatchFunction with the same
2106   // shared_name.
2107   pm.addPass(
2108       tfrt_compiler::CreateDeduplicateFunctionsInovkedByBatchFunctionPass());
2109 
2110   // Apply standard optimization after optimizing control flow ops.
2111   pm.addPass(mlir::createInlinerPass());
2112   pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
2113 
2114   // TODO(b/187876545): An extra shape inference pass is added because shape
2115   // inference does not work well with the tf.Identity op that removes the ref
2116   // type, so we work around this by running shape inference again after the
2117   // reference-to-resource-variable conversion. Remove this once b/187876545 is fixed.
2118   pm.addPass(mlir::TF::CreateTFShapeInferencePass());
2119 
2120   pm.addNestedPass<mlir::FuncOp>(
2121       mlir::TFDevice::CreateLaunchToDeviceAttributePass());
2122 
2123   // After all standard passes, run layout optimization to assign the optimal
2124   // data format for all layout-sensitive operations.
2125   mlir::TF::LayoutOptimizationPipelineOptions layout_optimization_options;
2126   layout_optimization_options.force_data_format =
2127       options.force_data_format.getValue();
2128   // TODO(b/191304261): Folding transpose in ops is buggy in the layout
2129   // optimization pass. Disable it to avoid errors in b/191304261. This should
2130   // not affect CPU performance as it does not change the number of ops, nor
2131   // does it change the types of the ops.
2132   layout_optimization_options.skip_fold_transpose_in_ops = true;
2133   mlir::TF::CreateLayoutOptimizationPipeline(pm.nest<mlir::FuncOp>(),
2134                                              layout_optimization_options);
2135 
2136   // Run canonicalization pipeline to remove unused constants and bypassed
2137   // transpose operations left in the IR after layout optimization.
2138   pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
2139 
2140   // Decompose resource ops as resource variables will be converted to tensors
2141   // directly. Only do this for non-TPU programs.
2142   if (options.decompose_resource_ops && !options.target_tpu)
2143     pm.addNestedPass<mlir::FuncOp>(
2144         mlir::TFDevice::CreateDecomposeResourceOpsPass());
2145 
2146   // Then we assign default devices.
2147   pm.addNestedPass<mlir::FuncOp>(
2148       mlir::TF::CreateSimpleTFDeviceAssignmentPass(options.default_device));
2149 
2150   pm.addNestedPass<mlir::FuncOp>(
2151       mlir::TF::CreateTensorDeviceCopyConversionPass());
2152 
2153   // Outline auto-fusion clusters into tf_device.cluster operations and then
2154   // convert them to functions. We currently support only tfrt fallback tensors
2155   // as operands, so we disable these passes if we can have native ops after
2156   // lowering.
2157   if (!options.enable_native_ops) {
2158     pm.addNestedPass<mlir::FuncOp>(CreateTfCpurtClusteringPass(
2159         options.auto_fusion_oplist, options.auto_fusion_min_cluster_size));
2160 
2161     // Sink small constants into the outlined clusters to reduce the number of
2162     // arguments for each of the execute operations.
2163     auto is_compilable_const = [](mlir::tf_device::ClusterOp cluster,
2164                                   mlir::ElementsAttr value) -> bool {
2165       // Ensure that cluster was formed for TFRT JIT compilation.
2166       auto policy = cluster->getAttr("policy").dyn_cast_or_null<StringAttr>();
2167       if (!policy || policy.getValue() != "tfrt.auto-fusion") return false;
2168 
2169       // Check that TF->CPURT compiler supports constant compilation.
2170       return mlir::succeeded(IsCompilableConstant(value));
2171     };
2172 
2173     pm.addNestedPass<mlir::FuncOp>(
2174         mlir::TFDevice::CreateClusterConstantSinkingPass(is_compilable_const));
2175 
2176     // Outline formed JIT compiled device clusters into function.
2177     pm.addPass(CreateOutlineCpuRtClustersPass());
2178   }
2179 
2180   // Rewrite operation sequences into device-specific fusions.
2181   DeviceNameUtils::ParsedName parsed_name;
2182 
2183   // Parse errors are only caught by the assert below (a no-op in release builds).
2184   bool success =
2185       DeviceNameUtils::ParseFullName(options.default_device, &parsed_name);
2186   assert(success && "default device is invalid");
2187   (void)success;
2188 
2189   if (parsed_name.has_type && parsed_name.type == DEVICE_GPU)
2190     pm.addNestedPass<mlir::FuncOp>(mlir::TF::CreateGpuOpFusionPass());
2191 
2192   if (parsed_name.has_type && parsed_name.type == DEVICE_CPU)
2193     pm.addNestedPass<mlir::FuncOp>(mlir::TF::CreateFusedKernelMatcherPass());
2194 
2195   pm.addPass(CreateLowerTFSavedModelPass(options.hoist_invariant_ops));
2196 }
2197 
2198 void CreateTfExecutorToTfrtPipelineHelper(mlir::OpPassManager &pm,
2199                                           const TfrtPipelineOptions &options) {
2200   CreateTFExecutorToTFPipeline(pm, options);
2201 
2202   pm.addPass(CreateTfToTfrtConversionPass(options));
2203 
2204   pm.addPass(CreateRemoveDeviceAttributePass());
2205 
2206   // Run optimizer on the MLIR module in CoreRT dialect.
2207   if (options.enable_optimizer) {
2208     pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
2209     pm.addPass(mlir::createInlinerPass());
2210     pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
2211     pm.addNestedPass<mlir::FuncOp>(
2212         tfrt_compiler::CreateInsertFallbackTensorCopyPass());
2213   }
2214 }
2215 
2216 // If verbose logging is on, dump the output of each pass to a file directory,
2217 // set via env var TF_DUMP_GRAPH_PREFIX. e.g.:
2218 // export TF_DUMP_GRAPH_PREFIX=/tmp/mlir
2219 void CreateTfExecutorToTfrtPipeline(mlir::PassManager &pm,
2220                                     const TfrtPipelineOptions &options) {
2221   if (VLOG_IS_ON(1)) {
2222     // Print the whole module after each pass, which requires disabling
2223     // multi-threading as well.
2224     pm.getContext()->disableMultithreading();
2225     pm.enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>(
2226         /*print_module_scope=*/true));
2227   }
2228   CreateTfExecutorToTfrtPipelineHelper(pm, options);
2229 }
2230 
2231 static mlir::PassRegistration<TfToTfrtConversionPass> tf_to_tfrt_pass;
2232 
2233 static mlir::PassRegistration<OutlineCpuRtClustersPass>
2234     tf_outline_cpurt_cluster_pass(CreateOutlineCpuRtClustersPass);
2235 
2236 static mlir::PassPipelineRegistration<TfrtPipelineOptions> tf_pipeline(
2237     "tf-executor-to-tfrt-pipeline",
2238     "Convert Tensorflow Executor dialect to TFRT dialect and "
2239     "also apply necessary optimization passes.",
2240     CreateTfExecutorToTfrtPipelineHelper);
2241 
2242 }  // namespace tensorflow
2243