/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements lowering of the TF dialect to TFRT CoreRuntime
// ExecuteOp. This lowering pass is heavily experimental and incomplete.
// External code should not depend on the code here, nor treat it as an
// example of "the path forward" for this conversion.

#include <vector>

#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassOptions.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h"
#include "tensorflow/compiler/mlir/tfrt/analysis/tensor_array_side_effect_analysis.h"
#include "tensorflow/compiler/mlir/tfrt/jit/opdefs/tf_cpurt_ops.h"
#include "tensorflow/compiler/mlir/tfrt/jit/tf_cpurt_clustering.h"
#include "tensorflow/compiler/mlir/tfrt/jit/tf_cpurt_passes.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/corert_converter.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/runtime_fallback/opdefs/tfrt_fallback.h"
#include "tensorflow/core/runtime_fallback/opdefs/tfrt_fallback_async.h"
#include "tfrt/cpu/jit/opdefs/cpurt_ops.h"  // from @tf_runtime
#include "tfrt/basic_kernels/opdefs/basic_kernels.h"  // from @tf_runtime
#include "tfrt/basic_kernels/opdefs/tfrt_base.h"  // from @tf_runtime
#include "tfrt/basic_kernels/opdefs/types.h"  // from @tf_runtime
#include "tfrt/core_runtime/opdefs/attributes.h"  // from @tf_runtime
#include "tfrt/core_runtime/opdefs/core_runtime.h"  // from @tf_runtime
#include "tfrt/core_runtime/opdefs/types.h"  // from @tf_runtime
#include "tfrt/distributed_runtime/opdefs/kernels.h"  // from @tf_runtime
#include "tfrt/distributed_runtime/opdefs/types.h"  // from @tf_runtime
#include "tfrt/test_kernels/opdefs/test_kernels.h"  // from @tf_runtime

namespace tensorflow {
namespace {

constexpr unsigned kFallbackBenefit = 1;
constexpr unsigned kCoreRTBenefit = 2;
constexpr char kGpuDeviceName[] =
    "/job:localhost/replica:0/task:0/device:GPU:0";
constexpr char kCpuDeviceName[] =
    "/job:localhost/replica:0/task:0/device:CPU:0";
constexpr char kTFDeviceAttr[] = "tf.device";
constexpr char kTFRTDeviceAttr[] = "tfrt.device";
constexpr char kDeviceAttr[] = "device";
constexpr char kHostAttr[] = "host";
constexpr char kDeviceTypeTpu[] = "TPU";

void getDependentConversionDialects(mlir::DialectRegistry &registry) {
  registry.insert<tfrt::corert::CoreRTDialect,
                  tfrt::fallback_async::FallbackAsyncDialect,
                  tfrt::compiler::TFRTDialect, tfrt::dist::DistributedDialect,
                  tf_cpurt::CpuRuntimeDialect>();
}

mlir::Value GetFunctionInputChain(mlir::Operation *op) {
  auto func_op = op->getParentOfType<mlir::FuncOp>();
  return func_op.getArgument(0);
}

// Convert TF dialect ops to tfrt_fallback.executeop for non-side-effecting ops
// and tfrt_fallback.executeop.seq for side-effecting ops.
//
// For example,
//
// %0 = "tf.MatMul"(%arg0, %arg1) {device = "cpu"} :
//          (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
//
// is converted to
//
// %result = tfrt_fallback.executeop device("cpu")
//    "tf.MatMul"(%arg0, %arg1) : 1
//
class FallbackExecuteOpConversion : public mlir::ConversionPattern {
 public:
  FallbackExecuteOpConversion(
      mlir::MLIRContext *context, CoreRTConverter *corert_converter,
      tfrt_compiler::FallbackConverter *fallback_converter,
      const tfrt_compiler::CostAnalysis *cost_analysis,
      bool tpu_lower_to_fallback)
      : mlir::ConversionPattern(mlir::Pattern::MatchAnyOpTypeTag(),
                                kFallbackBenefit, context),
        corert_converter_(*corert_converter),
        fallback_converter_(*fallback_converter),
        cost_analysis_(*cost_analysis),
        tpu_lower_to_fallback_(tpu_lower_to_fallback) {}

  LogicalResult matchAndRewrite(
      mlir::Operation *op, ArrayRef<mlir::Value> operands,
      mlir::ConversionPatternRewriter &rewriter) const override {
    if (!UseFallback(op)) return failure();

    corert_converter_.MaterializeDerivedAttributes(op);

    mlir::StringAttr device = op->getAttrOfType<mlir::StringAttr>(kDeviceAttr);
    if (!device || device.getValue().empty())
      return op->emitWarning("failed to find a non-empty 'device' attribute");
    op->removeAttr(kDeviceAttr);
    auto parsed_device_name =
        corert_converter_.ParseDeviceName(device.getValue());
    if (!parsed_device_name)
      return op->emitWarning("failed to parse the device name");

    // Convert the function (symbol) attributes to an array of string
    // attributes, which represents the function names.
    llvm::SmallVector<mlir::Identifier, 4> func_attr_keys;
    mlir::ArrayAttr op_func_attrs =
        corert_converter_.CreateOpFuncAttrs(op->getAttrs(), &func_attr_keys);

    // Remove the function attributes, which have already been processed.
    for (const auto &key : func_attr_keys) op->removeAttr(key);

    mlir::ArrayAttr op_attrs = corert_converter_.CreateOpAttrs(op->getAttrs());
    if (!op_attrs) return op->emitWarning("failed to lower attributes.");

    mlir::StringAttr op_name =
        rewriter.getStringAttr(op->getName().getStringRef());

    // Currently the fallback executeop does not support the GPU device. For
    // the GPU device, we still lower to a corert executeop, where more
    // devices are supported.
    //
    // TODO(b/166705169): Once GPU devices are supported in fallback, we can
    // remove the conversion to corert.
    if (parsed_device_name->device_type == DEVICE_GPU ||
        (parsed_device_name->device_type == kDeviceTypeTpu &&
         !tpu_lower_to_fallback_)) {
      return ConvertToCoreRTExecuteOp(
          op, operands, parsed_device_name->op_handler_name, op_attrs,
          op_func_attrs, op_name, rewriter);
    }

    // For DEVICE_CPU, DEVICE_DEFAULT, and DEVICE_TPU_SYSTEM, we use fallback.
    // Note that the TPU bridge should handle all ops that are required to be
    // executed on TPU, so if there are still ops placed on TPU at this stage
    // of lowering TF to TFRT, these ops are supposed to be executed on the
    // host.
    return ConvertToFallbackExecuteOp(op, operands, device, op_attrs,
                                      op_func_attrs, op_name,
                                      fallback_converter_, rewriter);
  }

 private:
  // Return true if this op can be lowered to fallback ops.
  bool UseFallback(mlir::Operation *op) const {
    if (!llvm::isa<mlir::TF::TensorFlowDialect>(op->getDialect())) return false;

    // Below is a blocklist of ops that should not go through CoreRTExecuteOp
    // conversion.
    // TODO(b/173017701): have a centralized place to hold the information
    // whether a TF op should be lowered to FallbackExecute op.
    return !llvm::isa<
        mlir::TF::_TfrtSetResourceOp, mlir::TF::_TfrtGetResourceOp,
        mlir::TF::_TPUCompileMlirOp, mlir::TF::TPUCompileSucceededAssertOp,
        mlir::TF::TPUExecuteOp,
        // Specifically handle control flow ops.
        mlir::TF::CaseOp, mlir::TF::IfOp, mlir::TF::WhileOp,
        mlir::TF::StatefulPartitionedCallOp, mlir::TF::PartitionedCallOp>(op);
  }

  mlir::LogicalResult ConvertToFallbackExecuteOp(
      mlir::Operation *op, llvm::ArrayRef<mlir::Value> operands,
      mlir::StringAttr device, mlir::ArrayAttr op_attrs,
      mlir::ArrayAttr op_func_attrs, mlir::StringAttr op_name,
      tfrt_compiler::FallbackConverter &fallback_converter,
      mlir::ConversionPatternRewriter &rewriter) const;

  mlir::LogicalResult ConvertToCoreRTExecuteOp(
      mlir::Operation *op, llvm::ArrayRef<mlir::Value> operands,
      llvm::StringRef op_handler_name, mlir::ArrayAttr op_attrs,
      mlir::ArrayAttr op_func_attrs, mlir::StringAttr op_name,
      mlir::ConversionPatternRewriter &rewriter) const;

  CoreRTConverter &corert_converter_;
  tfrt_compiler::FallbackConverter &fallback_converter_;
  const tfrt_compiler::CostAnalysis &cost_analysis_;
  bool tpu_lower_to_fallback_;
};

mlir::LogicalResult FallbackExecuteOpConversion::ConvertToFallbackExecuteOp(
    mlir::Operation *op, llvm::ArrayRef<mlir::Value> operands,
    mlir::StringAttr device, mlir::ArrayAttr op_attrs,
    mlir::ArrayAttr op_func_attrs, mlir::StringAttr op_name,
    tfrt_compiler::FallbackConverter &fallback_converter,
    mlir::ConversionPatternRewriter &rewriter) const {
  llvm::SmallVector<Type, 4> result_types(
      op->getNumResults(), rewriter.getType<tfrt::fallback::TFTensorType>());

  // Convert the operands to tfrt_fallback.tf_tensor if needed.
  llvm::SmallVector<mlir::Value, 4> new_operands;
  if (mlir::failed(tfrt_compiler::ConvertFallbackOperands(
          op, device.getValue(), operands, &new_operands, rewriter)))
    return failure();

  auto fallback_key =
      rewriter.getI64IntegerAttr(fallback_converter.GetNextFallbackKey());

  // Query cost analysis to assign costs.
  auto cost = rewriter.getI64IntegerAttr(cost_analysis_.GetCost(op));

  if (mlir::MemoryEffectOpInterface::hasNoEffect(op)) {
    auto new_op = rewriter.create<tfrt::fallback_async::ExecuteOp>(
        op->getLoc(), result_types, new_operands, device, op_attrs,
        op_func_attrs, fallback_key, op_name, cost);
    fallback_converter.RegisterFallbackOp(new_op);
    rewriter.replaceOp(op, new_op.results());
  } else {
    auto in_chain = corert_converter_.GetLocalSideEffectChain(op, &rewriter);
    auto out_chain = in_chain;

    if (tfrt_compiler::IsTensorArrayOp(op)) {
      // If it is a tensor array op, we don't need to use
      // tfrt_fallback_async.executeop.seq because its operands/results already
      // take care of control dependencies.
      auto new_op = rewriter.create<tfrt::fallback_async::ExecuteOp>(
          op->getLoc(), result_types, new_operands, device, op_attrs,
          op_func_attrs, fallback_key, op_name, cost);
      fallback_converter.RegisterFallbackOp(new_op);
      rewriter.replaceOp(op, new_op.results());
    } else {
      // Create tfrt_fallback.executeop.seq if it is a side-effecting op.
      auto new_op = rewriter.create<tfrt::fallback_async::ExecuteOpSeq>(
          op->getLoc(), corert_converter_.chain_type(), result_types, in_chain,
          new_operands, device, op_attrs, op_func_attrs, fallback_key, op_name,
          cost);
      fallback_converter.RegisterFallbackOp(new_op);
      rewriter.replaceOp(op, new_op.results());
      out_chain = new_op.out_op_chain();
    }

    // Register the converted op so that it can be retrieved by successors.
    corert_converter_.RegisterLocalSideEffectChain(op, out_chain);
  }

  return success();
}

mlir::LogicalResult FallbackExecuteOpConversion::ConvertToCoreRTExecuteOp(
    mlir::Operation *op, llvm::ArrayRef<mlir::Value> operands,
    llvm::StringRef op_handler_name, mlir::ArrayAttr op_attrs,
    mlir::ArrayAttr op_func_attrs, mlir::StringAttr op_name,
    mlir::ConversionPatternRewriter &rewriter) const {
  llvm::SmallVector<Type, 4> result_types(
      op->getNumResults(), rewriter.getType<tfrt::corert::TensorHandleType>());

  // Convert the operands to tensorhandles if needed.
  llvm::SmallVector<mlir::Value, 4> new_operands;
  if (mlir::failed(tfrt_compiler::ConvertCoreRTOperands(
          op, operands, &new_operands, rewriter)))
    return failure();

  // Get the op handler, or create one if one does not already exist. Note
  // that ConvertOpHandler changes internal state, so it can only be called if
  // the rewrite is guaranteed to succeed afterwards.
  auto op_handler =
      corert_converter_.ConvertOpHandler(op, op_handler_name, &rewriter);
  if (!op_handler) return failure();

  if (mlir::MemoryEffectOpInterface::hasNoEffect(op)) {
    auto new_op = rewriter.create<tfrt::corert::ExecuteOp>(
        op->getLoc(), result_types, op_handler, new_operands, op_attrs,
        op_func_attrs, op_name);
    rewriter.replaceOp(op, new_op.results());
  } else {
    // Create corert.executeop.seq if it is a side-effecting op.
    auto new_op = rewriter.create<tfrt::corert::ExecuteOpSeq>(
        op->getLoc(), corert_converter_.chain_type(), result_types, op_handler,
        corert_converter_.GetLocalSideEffectChain(op, &rewriter), new_operands,
        op_attrs, op_func_attrs, op_name);
    rewriter.replaceOp(op, new_op.results());

    // Register the converted op so that it can be retrieved by successors.
    corert_converter_.RegisterLocalSideEffectChain(op, new_op.out_op_chain());
  }

  return success();
}

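// Convert tf.Const ops that do not have a fast-path conversion (see the
// numeric and string const conversions further below) by serializing their
// value into a TensorProto. Illustrative example (the exact textual form of
// the fallback op is approximate):
//
//   %0 = "tf.Const"() {value = ...} : () -> tensor<2x!tf.qint8>
//
// is converted to roughly
//
//   %0 = tfrt_fallback_async.const_tensor_proto "<serialized TensorProto>"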
class FallbackConstOpConversion
    : public mlir::OpConversionPattern<mlir::TF::ConstOp> {
 public:
  FallbackConstOpConversion(mlir::MLIRContext *context,
                            CoreRTConverter *corert_converter)
      : mlir::OpConversionPattern<mlir::TF::ConstOp>(context),
        corert_converter_(*corert_converter) {}

  mlir::LogicalResult matchAndRewrite(
      mlir::TF::ConstOp op, llvm::ArrayRef<mlir::Value> operands,
      mlir::ConversionPatternRewriter &rewriter) const override {
    // Some data types are handled separately using a fast path.
    if (corert_converter_.IsSupportedNumericDType(op.dtype()) ||
        op.dtype().isa<mlir::TF::StringType>())
      return failure();

    // For other data types that do not have a fast path (e.g. quantized
    // types), we convert them to a serialized tensor proto.

    tensorflow::TensorProto tensor_proto;
    auto status = ConvertToTensorProto(op.value(), &tensor_proto);
    if (!status.ok()) return op.emitError(status.error_message());

    rewriter.replaceOpWithNewOp<tfrt::fallback_async::ConstTensorProtoOp>(
        op, rewriter.getType<tfrt::fallback::TFTensorType>(),
        tensor_proto.SerializeAsString());

    return success();
  }

 private:
  CoreRTConverter &corert_converter_;
};

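// Convert tf._TfrtSetResource, which stores a tensor into the runtime's
// resource array at a given index, to the corresponding fallback set-resource
// kernel. Illustrative example (the exact textual form of the fallback op is
// approximate):
//
//   "tf._TfrtSetResource"(%tensor) {index = 0 : i64} : (tensor<i32>) -> ()
//
// is converted to roughly
//
//   %out_ch = tfrt_fallback_async.set_resource %in_ch, %tensor {index = 0}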
class FallbackSetResourceOp
    : public mlir::OpConversionPattern<mlir::TF::_TfrtSetResourceOp> {
 public:
  FallbackSetResourceOp(mlir::MLIRContext *context,
                        CoreRTConverter *corert_converter)
      : mlir::OpConversionPattern<mlir::TF::_TfrtSetResourceOp>(context),
        corert_converter_(*corert_converter) {}

  mlir::LogicalResult matchAndRewrite(
      mlir::TF::_TfrtSetResourceOp op, llvm::ArrayRef<mlir::Value> operands,
      mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::StringAttr device = op->getAttrOfType<mlir::StringAttr>(kDeviceAttr);
    if (!device || device.getValue().empty())
      return op->emitWarning("failed to find a non-empty 'device' attribute");

    // Currently the static resource is always on host CPU.
    //
    // TODO(chky): Support resource on other devices.
    llvm::SmallVector<mlir::Value, 4> new_operands;
    if (mlir::failed(tfrt_compiler::ConvertFallbackOperands(
            op, kCpuDeviceName, operands, &new_operands, rewriter)))
      return failure();

    assert(new_operands.size() == 1);

    auto new_op = rewriter.create<tfrt::fallback_async::SetResourceOp>(
        op.getLoc(), corert_converter_.chain_type(),
        corert_converter_.GetLocalSideEffectChain(op, &rewriter),
        new_operands[0], device.getValue(), op.index());

    // Register the converted op so that it can be retrieved by successors.
    corert_converter_.RegisterLocalSideEffectChain(op, new_op.out_ch());

    rewriter.eraseOp(op);

    return success();
  }

 private:
  CoreRTConverter &corert_converter_;
};

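// Convert tf._TfrtGetResource, which reads one or more tensors back from the
// runtime's resource array, to the corresponding fallback get-resource
// kernel. Illustrative example (the exact textual form of the fallback op is
// approximate):
//
//   %0 = "tf._TfrtGetResource"() {indices = [0]} : () -> tensor<i32>
//
// is converted to roughly
//
//   %out_ch, %0 = tfrt_fallback_async.get_resource %in_ch {indices = [0]}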
class FallbackGetResourceOp
    : public mlir::OpConversionPattern<mlir::TF::_TfrtGetResourceOp> {
 public:
  FallbackGetResourceOp(mlir::MLIRContext *context,
                        CoreRTConverter *corert_converter)
      : mlir::OpConversionPattern<mlir::TF::_TfrtGetResourceOp>(context),
        corert_converter_(*corert_converter) {}

  mlir::LogicalResult matchAndRewrite(
      mlir::TF::_TfrtGetResourceOp op, llvm::ArrayRef<mlir::Value> operands,
      mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::StringAttr device = op->getAttrOfType<mlir::StringAttr>(kDeviceAttr);
    if (!device || device.getValue().empty())
      return op->emitWarning("failed to find a non-empty 'device' attribute");

    llvm::SmallVector<mlir::Type, 4> result_types(
        op.getNumResults(), rewriter.getType<tfrt::fallback::TFTensorType>());

    auto new_op = rewriter.create<tfrt::fallback_async::GetResourceOp>(
        op.getLoc(), corert_converter_.chain_type(), result_types,
        corert_converter_.GetLocalSideEffectChain(op, &rewriter),
        device.getValue(), op.indices());

    // Register the converted op so that it can be retrieved by successors.
    corert_converter_.RegisterLocalSideEffectChain(op, new_op.out_ch());

    rewriter.replaceOp(op, new_op.results());

    return success();
  }

 private:
  CoreRTConverter &corert_converter_;
};

// Convert a tf_device.remote_run op to a tfrt_dist.remote_execute_func op.
//
// For example,
//
// %result = tf_device.remote_run "/job:worker/replica:0/task:0"
//   @remote_func(%arg0) : (tensor<i32>) -> (tensor<i32>)
//
// is converted to
//
// %0 = tfrt_dist.get_distributed_context
// %1 = tfrt_dist.get_task_handle %0
//   {task_name = "/job:worker/replica:0/task:0"}
// %2 = tfrt_dist.test_create_remote_chain_manager %0
// %3 = tfrt_dist.get_chain_for_task_handle %in_chain, %2, %1
// %out_op_chain, %results:2 = tfrt_dist.remote_execute_func[%in_chain, %0, %1]
//   @remote_func(%3, %arg0): (!tfrt_dist.remote_object_id, !corert.tensorhandle)
//   -> (!tfrt_dist.remote_object_id, !corert.tensorhandle)
// %4 = tfrt_dist.set_chain_for_task_handle %out_op_chain, %2, %1, %results#0
//
class TFDeviceRemoteRunOpConversion
    : public mlir::OpConversionPattern<tf_device::RemoteRunOp> {
 public:
  TFDeviceRemoteRunOpConversion(mlir::MLIRContext *context,
                                mlir::TypeConverter *type_converter,
                                CoreRTConverter *corert_converter)
      : mlir::OpConversionPattern<tf_device::RemoteRunOp>(context,
                                                          kCoreRTBenefit),
        type_converter_(*type_converter),
        corert_converter_(*corert_converter) {}

  LogicalResult matchAndRewrite(
      tf_device::RemoteRunOp op, ArrayRef<mlir::Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    mlir::Value distributed_context =
        corert_converter_.GetDistributedContext(op.getOperation(), &rewriter);
    mlir::Value in_op_chain =
        corert_converter_.GetLocalSideEffectChain(op, &rewriter);
    mlir::Value task_handle = corert_converter_.GetTaskHandle(
        op.getOperation(), op.host(), &rewriter);
    mlir::Value remote_chain_mgr =
        corert_converter_.GetRemoteChainManager(op, &rewriter);
    mlir::Type remote_obj_id_ty =
        rewriter.getType<tfrt::dist::RemoteObjectIdType>();
    ModuleOp module = op->getParentOfType<ModuleOp>();
    SymbolTable symtab(module);
    FuncOp callee = symtab.lookup<FuncOp>(op.callee());
    if (!callee) {
      op.emitOpError("callee function ") << op.callee() << " is not found";
      return failure();
    }
    StringAttr host = callee->getAttrOfType<StringAttr>(kHostAttr);
    if (!host) {
      op.emitOpError("callee function ")
          << op.callee() << " should have the host attribute";
      return failure();
    }

    llvm::SmallVector<mlir::Value, 4> arguments;
    // The first argument of the remote function should be a remote chain which
    // is added to the function signature when it is lowered from TF dialect to
    // TFRT dialect.
    arguments.push_back(corert_converter_.GetRemoteSideEffectChain(
        op, host.getValue(), &rewriter));
    for (mlir::Value argument : op.callee_args()) {
      arguments.push_back(argument);
    }

    llvm::SmallVector<mlir::Type, 4> result_types;
    // The first result of the remote function should be a remote chain which
    // is added to the function signature when it is lowered from TF dialect to
    // TFRT dialect.
    result_types.push_back(remote_obj_id_ty);
    for (mlir::Type type : op.getResultTypes()) {
      (void)type_converter_.convertType(type, result_types);
    }
    auto remote_execute_func_op =
        rewriter.create<tfrt::dist::RemoteExecuteFuncOp>(
            op.getLoc(), corert_converter_.chain_type(), result_types,
            in_op_chain, distributed_context, task_handle, op.callee(),
            arguments);
    rewriter.replaceOp(op, remote_execute_func_op.results().drop_front(1));

    auto set_chain_op = rewriter.create<tfrt::dist::SetChainForTaskHandleOp>(
        op.getLoc(), corert_converter_.chain_type(),
        remote_execute_func_op.out_op_chain(), remote_chain_mgr, task_handle,
        remote_execute_func_op.results().front());
    corert_converter_.RegisterLocalSideEffectChain(op,
                                                   set_chain_op.out_op_chain());

    return success();
  }

 private:
  mlir::TypeConverter &type_converter_;
  CoreRTConverter &corert_converter_;
};

// Lowers a tf.BatchFunction to tfrt_fallback.batch_function.
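//
// Illustrative example (the exact textual form of the batch_function op is
// approximate):
//
//   %0 = "tf.BatchFunction"(%arg) {f = @batched_fn, ...}
//       : (tensor<?xi32>) -> tensor<?xi32>
//
// is converted to roughly
//
//   %out_ch, %0 = tfrt_fallback.batch_function(%in_ch) @batched_fn (%arg)
//       {batching attributes...}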
class FallbackBatchFunctionOpConversion
    : public mlir::OpConversionPattern<mlir::TF::BatchFunctionOp> {
 public:
  FallbackBatchFunctionOpConversion(mlir::MLIRContext *context,
                                    CoreRTConverter *corert_converter)
      : mlir::OpConversionPattern<mlir::TF::BatchFunctionOp>(context,
                                                             kFallbackBenefit),
        corert_converter_(*corert_converter) {}

  LogicalResult matchAndRewrite(
      mlir::TF::BatchFunctionOp op, ArrayRef<mlir::Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    corert_converter_.MaterializeDerivedAttributes(op);

    // Remove the device attribute for fallback, as currently fallback will
    // select the device automatically.
    //
    // TODO(chky): The device attribute should be passed explicitly. This can
    // be done once we change the kernel implementation to choose the device
    // based on attributes.
    op->removeAttr(rewriter.getIdentifier(kDeviceAttr));

    SymbolRefAttr f = op.fAttr();

    llvm::SmallVector<NamedAttribute, 12> attr_array;
    for (auto &key_and_value : op->getAttrs()) {
      StringRef name = key_and_value.first;
      if (name == "_output_shapes" || name == "f") {
        continue;
      }
      attr_array.push_back(key_and_value);
    }
    ArrayAttr op_attrs = corert_converter_.CreateOpAttrs(attr_array);
    if (!op_attrs) return op.emitWarning("failed to lower attributes.");

    llvm::SmallVector<Type, 4> result_types;
    for (auto type : op.getResultTypes()) {
      if (failed(corert_converter_.convertType(type, result_types)))
        return failure();
    }

    llvm::SmallVector<mlir::Value, 4> new_operands;
    if (mlir::failed(tfrt_compiler::ConvertCoreRTOperands(
            op, operands, &new_operands, rewriter)))
      return failure();

    auto new_op = rewriter.create<tfrt::fallback_async::BatchFunctionOp>(
        op.getLoc(), corert_converter_.chain_type(), result_types,
        corert_converter_.GetLocalSideEffectChain(op, &rewriter), new_operands,
        f, op_attrs);
    rewriter.replaceOp(op, new_op.results());

    // Register the converted op so that it can be retrieved by successors.
    corert_converter_.RegisterLocalSideEffectChain(op, new_op.out_op_chain());

    return success();
  }

 private:
  CoreRTConverter &corert_converter_;
};

// Lower a tf.Const op that creates a dense numeric tensor to a native
// corert.const_dense_tensor op.
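//
// Illustrative example (the exact textual form of the corert op is
// approximate):
//
//   %0 = "tf.Const"() {value = dense<1.0> : tensor<2xf32>} : () -> tensor<2xf32>
//
// is converted to roughly
//
//   %0 = corert.const_dense_tensor dense<1.0> : tensor<2xf32>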
class CoreRTConstDenseTensorOpConversion
    : public mlir::OpConversionPattern<mlir::TF::ConstOp> {
 public:
  CoreRTConstDenseTensorOpConversion(mlir::MLIRContext *context,
                                     CoreRTConverter *corert_converter)
      : mlir::OpConversionPattern<mlir::TF::ConstOp>(context, kCoreRTBenefit),
        corert_converter_(*corert_converter) {}

  LogicalResult matchAndRewrite(
      mlir::TF::ConstOp op, ArrayRef<mlir::Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    if (!corert_converter_.IsSupportedNumericDType(op.dtype()))
      return failure();

    // Only CPU ops can be lowered using this conversion. If there is no device
    // assignment, this op is treated as a CPU op and can be lowered.
    if (auto parsed_device_name = corert_converter_.ParseDeviceName(op))
      if (parsed_device_name->device_type != DEVICE_CPU) return failure();

    auto new_op = rewriter.create<tfrt::corert::ConstDenseTensorOp>(
        op.getLoc(), corert_converter_.tensor_handle_type(),
        op.value().cast<DenseElementsAttr>());
    rewriter.replaceOp(op, new_op->getResult(0));
    return success();
  }

 private:
  CoreRTConverter &corert_converter_;
};

// Convert the FuncOp with the following changes to meet TFRT's requirements:
// 1) Convert types for the arguments and the results.
// 2) Add a chain to the arguments.
// 3) Add a chain to the results for side-effects.
// 4) If any argument has a tf.device attribute, change the attribute name
//    to tfrt.device.
// 5) If any result has a tf.device attribute, change the attribute name
//    to tfrt.device.
//
// The input chain is used to signal visibility of all side-effects before
// calling this function. The output chain is used to signal visibility of all
// side-effects of this function.
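//
// For example, a function signature like
//
//   func @fn(%arg0: tensor<i32>) -> tensor<i32>
//
// becomes roughly (assuming tensors are converted to !corert.tensorhandle)
//
//   func @fn(%in_ch: !tfrt.chain, %arg0: !corert.tensorhandle)
//       -> (!tfrt.chain, !corert.tensorhandle)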
class TFRTFuncOpSignatureConversion
    : public mlir::OpConversionPattern<mlir::FuncOp> {
 public:
  TFRTFuncOpSignatureConversion(mlir::MLIRContext *ctx,
                                mlir::TypeConverter *type_converter)
      : OpConversionPattern(ctx), type_converter_(*type_converter) {}

  LogicalResult matchAndRewrite(
      mlir::FuncOp func_op, llvm::ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    mlir::FunctionType type = func_op.getType();

    // Convert the original function arguments.
    mlir::TypeConverter::SignatureConversion converted_signature(
        type.getNumInputs());
    // Add a chain as the first input.
    converted_signature.addInputs(
        rewriter.getType<tfrt::compiler::ChainType>());

    // Convert the original function results.
    SmallVector<Type, 1> converted_results;
    // Add a chain as the first result.
    converted_results.push_back(rewriter.getType<tfrt::compiler::ChainType>());

    if (failed(type_converter_.convertSignatureArgs(type.getInputs(),
                                                    converted_signature)) ||
        failed(type_converter_.convertTypes(type.getResults(),
                                            converted_results)) ||
        failed(rewriter.convertRegionTypes(&func_op.getBody(), type_converter_,
                                           &converted_signature))) {
      return failure();
    }

    llvm::SmallVector<mlir::DictionaryAttr, 4> arg_attrs;
    // The first input, which is a chain added by this pass, has no attribute.
    arg_attrs.emplace_back();
    func_op.getAllArgAttrs(arg_attrs);
    // If any argument has a tf.device attribute, change the attribute name to
    // tfrt.device.
    for (auto &arg_attr : arg_attrs) {
      mlir::NamedAttrList arg_attr_list(arg_attr);
      if (Attribute device = arg_attr_list.erase(kTFDeviceAttr)) {
        arg_attr_list.set(kTFRTDeviceAttr, device);
        arg_attr = arg_attr_list.getDictionary(device.getContext());
      }
    }
    arg_attrs.resize(converted_signature.getConvertedTypes().size());

    // The first result, which is a chain added by this pass, has no attribute.
    llvm::SmallVector<mlir::DictionaryAttr, 4> result_attrs;
    result_attrs.emplace_back();
    func_op.getAllResultAttrs(result_attrs);
    // If any result has a tf.device attribute, change the attribute name to
    // tfrt.device.
    for (auto &result_attr : result_attrs) {
      mlir::NamedAttrList result_attr_list(result_attr);
      if (Attribute device = result_attr_list.erase(kTFDeviceAttr)) {
        result_attr_list.set(kTFRTDeviceAttr, device);
        result_attr = result_attr_list.getDictionary(device.getContext());
      }
    }
    result_attrs.resize(converted_results.size());

    // Update the function signature in-place.
    rewriter.updateRootInPlace(func_op, [&] {
      func_op.setType(mlir::FunctionType::get(
          func_op.getContext(), converted_signature.getConvertedTypes(),
          converted_results));
      func_op.setAllArgAttrs(arg_attrs);
      func_op.setAllResultAttrs(result_attrs);
    });

    return success();
  }

 private:
  mlir::TypeConverter &type_converter_;
};

// Lower a tf.Const op that creates a string tensor to a native
// corert.create_string_tensor op.
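//
// Illustrative example (the exact textual form of the corert op is
// approximate):
//
//   %0 = "tf.Const"() {value = dense<["a", "b"]> : tensor<2x!tf.string>}
//       : () -> tensor<2x!tf.string>
//
// is converted to roughly
//
//   %0 = corert.const_string_tensor {shape = [2], value = ["a", "b"]}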
class CoreRTConstStringTensorOpConversion
    : public mlir::OpConversionPattern<mlir::TF::ConstOp> {
 public:
  CoreRTConstStringTensorOpConversion(mlir::MLIRContext *context,
                                      CoreRTConverter *corert_converter)
      : mlir::OpConversionPattern<mlir::TF::ConstOp>(context, kCoreRTBenefit),
        corert_converter_(*corert_converter) {}

  LogicalResult matchAndRewrite(
      mlir::TF::ConstOp op, ArrayRef<mlir::Value> operands,
      ConversionPatternRewriter &rewriter) const override {  // NOLINT
    if (!op.dtype().isa<mlir::TF::StringType>()) return failure();

    DenseStringElementsAttr attr = op.value().cast<DenseStringElementsAttr>();

    llvm::SmallVector<Attribute, 4> values;
    values.reserve(attr.getNumElements());
    for (const auto &element : attr.getRawStringData())
      values.push_back(rewriter.getStringAttr(
          llvm::StringRef(element.data(), element.size())));

    // Create the shape attribute from the tensor shape.
    ArrayRef<int64_t> shape = op.value().getType().getShape();
    llvm::SmallVector<mlir::Attribute, 4> dims;
    dims.reserve(shape.size());
    auto i64_type = rewriter.getIntegerType(64);
    for (auto dim : shape)
      dims.push_back(rewriter.getIntegerAttr(i64_type, dim));

    auto new_op = rewriter.create<tfrt::corert::ConstStringTensorOp>(
        op.getLoc(), corert_converter_.tensor_handle_type(),
        rewriter.getArrayAttr(dims), rewriter.getArrayAttr(values));

    rewriter.replaceOp(op, new_op.result());

    return success();
  }

 private:
  CoreRTConverter &corert_converter_;
};

// Convert TF dialect operations with no side effects to CoreRT ExecuteOp. For
// example,
//
// %0 = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} :
//          (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
//
// is converted to
//
// %result = corert.executeop(%device)
//   "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} :
//   (!corert.tensorhandle, !corert.tensorhandle) -> !corert.tensorhandle
//
// Note that it will fail to match if some attributes are not supported.
template <typename TF_Op>
class CoreRTExecuteOpConversion : public mlir::OpConversionPattern<TF_Op> {
 public:
  CoreRTExecuteOpConversion(mlir::MLIRContext *context,
                            CoreRTConverter *corert_converter)
      : CoreRTExecuteOpConversion(context, corert_converter, "") {}

  // If device_name is not empty, only ops that are assigned to this device are
  // lowered using CoreRTExecuteOpConversion.
  CoreRTExecuteOpConversion(mlir::MLIRContext *context,
                            CoreRTConverter *corert_converter,
                            llvm::StringRef device_name)
      : mlir::OpConversionPattern<TF_Op>(context, kCoreRTBenefit),
        corert_converter_(*corert_converter),
        device_name_(device_name) {}

  LogicalResult matchAndRewrite(
      TF_Op op, ArrayRef<mlir::Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto parsed_device_name = corert_converter_.ParseDeviceName(op);
    // Return failure and emit a warning if there is no device assignment.
    if (!parsed_device_name) {
      return op->emitWarning(
          "failed to retrieve valid device when converting to "
          "corert.executeop");
    }

    // If device_name is specified, check the device of this op first.
    if (!device_name_.empty()) {
      // Skip if it does not match the specified device.
      if (parsed_device_name->device_name != device_name_) return failure();
    }

    mlir::StringAttr op_name = rewriter.getStringAttr(op.getOperationName());

    llvm::SmallVector<Type, 4> result_types;
    for (auto type : op.getOperation()->getResultTypes()) {
      if (failed(corert_converter_.convertType(type, result_types)))
        return failure();
    }

    corert_converter_.MaterializeDerivedAttributes(op);

    // Convert the function (symbol) attributes to an array of string
    // attributes, which represents the function names.
    llvm::SmallVector<mlir::Identifier, 4> func_attr_keys;
    ArrayAttr op_func_attrs =
        corert_converter_.CreateOpFuncAttrs(op->getAttrs(), &func_attr_keys);

    // Remove the function attributes, which have already been processed.
    for (const auto &key : func_attr_keys) op->removeAttr(key);

    ArrayAttr op_attrs = corert_converter_.CreateOpAttrs(op->getAttrs());
    if (!op_attrs) return op.emitError("failed to lower attributes.");

    llvm::SmallVector<mlir::Value, 4> new_operands;
    if (mlir::failed(tfrt_compiler::ConvertCoreRTOperands(
            op, operands, &new_operands, rewriter)))
      return failure();

    // Get the op handler, or create one if one does not already exist. Note
    // that ConvertOpHandler changes internal state, so it can only be called
    // if the rewrite is guaranteed to succeed afterwards.
    auto op_handler = corert_converter_.ConvertOpHandler(
        op, parsed_device_name->op_handler_name, &rewriter);
    if (!op_handler) return failure();

    auto new_op = rewriter.create<tfrt::corert::ExecuteOp>(
        op.getLoc(), result_types, op_handler, new_operands, op_attrs,
        op_func_attrs, op_name);

    rewriter.replaceOp(op, new_op.results());
    return success();
  }

 private:
  CoreRTConverter &corert_converter_;
  llvm::StringRef device_name_;
};

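// Convert the operands of a function call (or a call-like op, e.g. the
// control flow conversions below) to the tensor representation used inside
// the callee: !tfrt_fallback.tf_tensor when func_use_fallback_tensor is true,
// and !corert.tensorhandle otherwise.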
LogicalResult ConvertFunctionCallOperands(
    mlir::Operation *op, llvm::ArrayRef<mlir::Value> operands,
    llvm::SmallVectorImpl<mlir::Value> *new_operands,
    mlir::ConversionPatternRewriter &rewriter, bool func_use_fallback_tensor) {
  if (func_use_fallback_tensor) {
    // TODO(b/182232457): Support other devices.
    return tfrt_compiler::ConvertFallbackOperands(op, kCpuDeviceName, operands,
                                                  new_operands, rewriter);
  } else {
    return tfrt_compiler::ConvertCoreRTOperands(op, operands, new_operands,
                                                rewriter);
  }
}

// Convert TF call ops (e.g. StatefulPartitionedCall) to tfrt.call.
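//
// For example (illustrative; the exact textual form may differ),
//
//   %0 = "tf.StatefulPartitionedCall"(%arg) {f = @callee, ...}
//       : (tensor<i32>) -> tensor<i32>
//
// is converted to roughly
//
//   %out_ch, %0 = tfrt.call @callee(%in_ch, %arg)
//       : (!tfrt.chain, !corert.tensorhandle) -> (!tfrt.chain, !corert.tensorhandle)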
template <typename CallOp>
class TFRTCallOpConversion : public mlir::OpConversionPattern<CallOp> {
 public:
  TFRTCallOpConversion(mlir::MLIRContext *context,
                       mlir::TypeConverter *type_converter,
                       CoreRTConverter *corert_converter,
                       bool func_use_fallback_tensor)
      : mlir::OpConversionPattern<CallOp>(context),
        type_converter_(*type_converter),
        corert_converter_(*corert_converter),
        func_use_fallback_tensor_(func_use_fallback_tensor) {}

  LogicalResult matchAndRewrite(
      CallOp op, ArrayRef<mlir::Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto callee =
        op.getCallableForCallee().template dyn_cast<mlir::SymbolRefAttr>();
    if (!callee) return failure();

    llvm::SmallVector<mlir::Value, 4> new_operands;
    new_operands.push_back(
        corert_converter_.GetLocalSideEffectChain(op, &rewriter));
    // Currently the converted functions always use !corert.tensorhandle types
    // for tensor arguments and results.
    //
    // TODO(b/175881042): We should avoid the tensor conversion here if the
    // operand is !tfrt_fallback.tf_tensor, and it is also used as fallback
    // tensor inside the callee function.
    if (mlir::failed(ConvertFunctionCallOperands(
            op, operands, &new_operands, rewriter, func_use_fallback_tensor_)))
      return failure();

    llvm::SmallVector<mlir::Type, 4> result_types;
    result_types.push_back(rewriter.getType<tfrt::compiler::ChainType>());
    for (auto type : op.getOperation()->getResultTypes()) {
      if (failed(type_converter_.convertType(type, result_types)))
        return failure();
    }

    auto new_op = rewriter.create<tfrt::compiler::CallOp>(
        op.getLoc(), result_types, callee.getRootReference(), new_operands);
    rewriter.replaceOp(op, new_op.getResults().drop_front());

    if (!mlir::MemoryEffectOpInterface::hasNoEffect(op)) {
      // Register the converted op so that it can be retrieved by successors.
      // TODO(chky): Add OpTraits or OpInterface, rather than assume first
      // result is a chain.
      corert_converter_.RegisterLocalSideEffectChain(op, new_op.getResult(0));
    }

    return success();
  }

 private:
  mlir::TypeConverter &type_converter_;
  CoreRTConverter &corert_converter_;
  bool func_use_fallback_tensor_;
};

// Convert standard ReturnOp to tfrt.return.
//
// TODO(chky): conversion to tfrt kernels should come from a common tf_to_tfrt
// library.
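//
// For example (illustrative; the exact textual form may differ),
//
//   return %0 : tensor<i32>
//
// is converted to roughly
//
//   tfrt.return %out_ch, %0 : !tfrt.chain, !corert.tensorhandle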
class TFRTReturnOpConversion
    : public mlir::OpConversionPattern<mlir::ReturnOp> {
 public:
  TFRTReturnOpConversion(mlir::MLIRContext *context,
                         CoreRTConverter *corert_converter,
                         bool func_use_fallback_tensor)
      : mlir::OpConversionPattern<mlir::ReturnOp>(context),
        corert_converter_(*corert_converter),
        func_use_fallback_tensor_(func_use_fallback_tensor) {}

  LogicalResult matchAndRewrite(
      mlir::ReturnOp op, ArrayRef<mlir::Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<mlir::Value, 2> new_operands;

    // Currently in mlir::TF::SideEffectAnalysis, all terminator ops are treated
    // as side-effect ops and they have predecessors but not successors.
    //
    // TODO(chky): ReturnOp has no side effect. Once the special handling in
    // mlir::TF::SideEffectAnalysis is removed, the chains should come from
    // side-effecting ops with no successors in the function.
    new_operands.push_back(
        corert_converter_.GetLocalSideEffectChain(op, &rewriter));
    if (mlir::failed(ConvertFunctionCallOperands(
            op, operands, &new_operands, rewriter, func_use_fallback_tensor_)))
      return failure();

    rewriter.replaceOpWithNewOp<tfrt::compiler::ReturnOp>(op, new_operands);
    return success();
  }

 private:
  CoreRTConverter &corert_converter_;
  bool func_use_fallback_tensor_;
};

// Convert tf.Case op to tfrt.case.
//
// TF dialect:
// %outputs = "tf.Case"(%arg, ...) { branches = [@branch0, @branch1], ...}
//
// lowered TFRT CoreRT dialect:
// %idx_int = corert.tensorhandle_to_int32 %idx
// %out_chain, %outputs = tfrt.case %idx_int [@branch0, @branch1] (%chain, ...)
class TFRTCaseOpConversion : public mlir::OpConversionPattern<TF::CaseOp> {
 public:
  TFRTCaseOpConversion(mlir::MLIRContext *context,
                       mlir::TypeConverter *type_converter,
                       CoreRTConverter *corert_converter,
                       bool func_use_fallback_tensor)
      : mlir::OpConversionPattern<TF::CaseOp>(context),
        type_converter_(*type_converter),
        corert_converter_(*corert_converter),
        func_use_fallback_tensor_(func_use_fallback_tensor) {}

  LogicalResult matchAndRewrite(
      TF::CaseOp op, ArrayRef<mlir::Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    mlir::ArrayAttr branches = op.branches();

    llvm::SmallVector<mlir::Type, 4> result_types;
    for (mlir::Type type : op->getResultTypes()) {
      if (failed(type_converter_.convertType(type, result_types)))
        return failure();
    }

    llvm::SmallVector<mlir::Value, 4> branch_operands;
    if (mlir::failed(ConvertFunctionCallOperands(op, operands.drop_front(),
                                                 &branch_operands, rewriter,
                                                 func_use_fallback_tensor_)))
      return failure();

    mlir::Value index_operand = operands[0];
    // TODO(b/182233401): Support TF tensor; remove the conversion op here.
    if (index_operand.getType().isa<tfrt::fallback::TFTensorType>()) {
      // TODO(b/182232457): Support other devices.
      index_operand =
          rewriter
              .create<
                  tfrt::fallback_async::FallbackTensorToCoreRTTensorHandleOp>(
                  op.getLoc(),
                  rewriter.getType<tfrt::corert::TensorHandleType>(),
                  operands[0], kCpuDeviceName)
              .getResult(0);
    }
    if (!index_operand.getType().isa<tfrt::corert::TensorHandleType>())
      return op.emitError(
          "branch index operand is expected to be a TensorHandle.");
    mlir::Value index_value =
        rewriter.create<tfrt::corert::TensorHandleToIntOp>(
            op.getLoc(), rewriter.getI32Type(), index_operand);

    auto new_op = rewriter.create<tfrt::compiler::CaseOp>(
        op.getLoc(), corert_converter_.chain_type(), result_types, index_value,
        branches, corert_converter_.GetLocalSideEffectChain(op, &rewriter),
        branch_operands);

    rewriter.replaceOp(op, new_op.branch_outputs());
    return success();
  }

 private:
  mlir::TypeConverter &type_converter_;
  CoreRTConverter &corert_converter_;
  bool func_use_fallback_tensor_;
};

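// Convert a TF condition tensor into an i1 value by emitting a fallback
// predicate kernel. If the operand is not already a fallback tensor, it is
// first converted to one (currently assuming the host CPU device). Returns a
// null value and emits a warning if the conversion fails.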
static mlir::Value GetPredicate(mlir::Operation *op, mlir::Value cond_operand,
                                mlir::ConversionPatternRewriter &rewriter) {
  if (!cond_operand.getType().isa<tfrt::fallback::TFTensorType>()) {
    cond_operand = tfrt_compiler::ConvertCoreRTTensorHandleToFallbackTensor(
        op->getLoc(), kCpuDeviceName, cond_operand, rewriter);
    if (!cond_operand) {
      op->emitWarning(
          "failed to convert the cond operand to fallback tensor.");
      return {};
    }
  }

  return rewriter.create<tfrt::fallback_async::PredicateOp>(
      op->getLoc(), rewriter.getI1Type(), cond_operand);
}

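// Convert tf.If to tfrt.cond. The boolean condition is computed from the
// first operand by a fallback predicate kernel. For example (illustrative;
// the exact textual form may differ),
//
//   %0 = "tf.If"(%cond, %arg) {then_branch = @then, else_branch = @else, ...}
//
// is converted to roughly
//
//   %cond_bool = tfrt_fallback_async.predicate %cond_tensor
//   %out_ch, %0 = tfrt.cond %cond_bool @then @else (%in_ch, %arg)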
class TFRTCondOpConversion : public mlir::OpConversionPattern<mlir::TF::IfOp> {
 public:
  TFRTCondOpConversion(mlir::MLIRContext *context,
                       mlir::TypeConverter *type_converter,
                       CoreRTConverter *corert_converter,
                       bool func_use_fallback_tensor)
      : mlir::OpConversionPattern<TF::IfOp>(context),
        type_converter_(*type_converter),
        corert_converter_(*corert_converter),
        func_use_fallback_tensor_(func_use_fallback_tensor) {}

  mlir::LogicalResult matchAndRewrite(
      mlir::TF::IfOp op, llvm::ArrayRef<mlir::Value> operands,
      mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::FlatSymbolRefAttr then_branch = op.then_branchAttr();
    mlir::FlatSymbolRefAttr else_branch = op.else_branchAttr();

    llvm::SmallVector<Type, 4> result_types;
    result_types.push_back(rewriter.getType<tfrt::compiler::ChainType>());
    for (mlir::Type type : op.getOperation()->getResultTypes()) {
      if (failed(type_converter_.convertType(type, result_types)))
        return failure();
    }

    // Convert the cond tensor to a boolean value so that it can be used by
    // the tfrt.cond kernel.
    auto bool_cond = GetPredicate(op, operands[0], rewriter);
    if (!bool_cond) return failure();

    llvm::SmallVector<mlir::Value, 4> new_operands;
    // The first arg of converted branch functions should be !tfrt.chain.
    new_operands.push_back(
        corert_converter_.GetLocalSideEffectChain(op, &rewriter));

    if (mlir::failed(ConvertFunctionCallOperands(op, operands.drop_front(),
                                                 &new_operands, rewriter,
                                                 func_use_fallback_tensor_)))
      return failure();

    auto new_op = rewriter.create<tfrt::compiler::CondOp>(
        op.getLoc(), result_types, bool_cond, then_branch, else_branch,
        new_operands);

    // The first result is a !tfrt.chain.
    rewriter.replaceOp(op, new_op.getResults().drop_front(1));

    if (!mlir::MemoryEffectOpInterface::hasNoEffect(op)) {
      // Register the converted op so that it can be retrieved by successors.
      // TODO(chky): Add OpTraits or OpInterface, rather than assume first
      // result is a chain.
      corert_converter_.RegisterLocalSideEffectChain(op, new_op.getResult(0));
    }

    return success();
  }

 private:
  mlir::TypeConverter &type_converter_;
  CoreRTConverter &corert_converter_;
  bool func_use_fallback_tensor_;
};

// Convert TF WhileOp to tfrt.while. tfrt.while uses a boolean condition and
// has slightly different semantics from tf.While for performance and
// generality. The pseudo code of tfrt.while is as follows:
//
//   while(cond) {
//     outputs, cond = body(inputs)
//     inputs = outputs
//   }
//   return outputs
//
// So we need to insert extra conversion kernels and merge functions when
// lowering tf.While to tfrt.while.
//
//   %result = tf.While(%arg) {cond = @original_cond_fn, body =
//     @original_body_fn}
//
// is converted to
//
//   func @new_pred_fn(%ch, %arg) {
//     %ch0, %cond_tensor = tfrt.call @original_cond_fn(%ch, %arg)
//     %cond_bool = tfrt_fallback_async.predicate %cond_tensor
//     tfrt.return %ch0, %cond_bool
//   }
//
//   func @new_while_body(%ch, %arg) {
//     %ch0, %result = tfrt.call @original_body_fn(%ch, %arg)
//     %ch1, %cond_bool = tfrt.call @new_pred_fn(%ch0, %result)
//     tfrt.return %ch1, %result, %cond_bool
//   }
//
//   %ch0, %first_iter_cond = tfrt.call @new_pred_fn(%ch, %arg)
//   %ch1, %result = tfrt.while %first_iter_cond @new_while_body(%ch0, %arg)
//
class TFRTWhileOpConversion
    : public mlir::OpConversionPattern<mlir::TF::WhileOp> {
 public:
  TFRTWhileOpConversion(mlir::MLIRContext *context,
                        mlir::TypeConverter *type_converter,
                        CoreRTConverter *corert_converter,
                        mlir::SymbolTable *symbol_table,
                        const tfrt_compiler::TensorArraySideEffectAnalysis
                            *tensor_array_side_effect_analysis,
                        bool func_use_fallback_tensor)
      : mlir::OpConversionPattern<TF::WhileOp>(context),
        type_converter_(*type_converter),
        corert_converter_(*corert_converter),
        symbol_table_(*symbol_table),
        tensor_array_side_effect_analysis_(*tensor_array_side_effect_analysis),
        func_use_fallback_tensor_(func_use_fallback_tensor) {}

  mlir::LogicalResult matchAndRewrite(
      mlir::TF::WhileOp op, llvm::ArrayRef<mlir::Value> operands,
      mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::FlatSymbolRefAttr cond_fn = op.condAttr();
    mlir::FlatSymbolRefAttr body_fn = op.bodyAttr();
    // TODO(hanbinyoon): Support the parallel_iterations attribute.

    llvm::SmallVector<Type, 4> while_arg_result_types;
    // Insert a chain for side effects as the first argument/result.
    while_arg_result_types.push_back(
        rewriter.getType<tfrt::compiler::ChainType>());
    for (mlir::Type type : op.getOperation()->getResultTypes()) {
      if (failed(type_converter_.convertType(type, while_arg_result_types)))
        return failure();
    }

    // Convert the operands to either fallback tensors or corert tensor
    // handles, as specified in the option.
    llvm::SmallVector<mlir::Value, 4> new_operands;
    if (mlir::failed(ConvertFunctionCallOperands(
            op, operands, &new_operands, rewriter, func_use_fallback_tensor_)))
      return failure();

    // Create the predicate function that calls the original cond function and
    // in addition converts the result to a boolean value.
    mlir::FuncOp pred_fn =
        GetPredicateFunction(op, cond_fn, while_arg_result_types, rewriter);
    if (!pred_fn) return failure();

    auto in_chain = corert_converter_.GetLocalSideEffectChain(op, &rewriter);
    auto out_chain = in_chain;

    bool has_at_most_tensor_array_effect = HasAtMostTensorArrayEffect(op);

    // Prepare the arguments to call the pred function for the first iteration.
    llvm::SmallVector<mlir::Value, 4> pred_args;
    pred_args.push_back(
        has_at_most_tensor_array_effect ? GetFunctionInputChain(op) : in_chain);
    pred_args.append(new_operands.begin(), new_operands.end());

    // Insert a call op to call the pred function for the first iteration.
    auto call_pred_fn = rewriter.create<tfrt::compiler::CallOp>(
        op.getLoc(), pred_fn.getType().getResults(), pred_fn.sym_name(),
        pred_args);

    auto pred_chain = call_pred_fn.getResult(0);
    auto first_iteration_bool_cond = call_pred_fn.getResult(1);

    mlir::FuncOp new_body_fn = GetWhileBodyFunction(
        op, body_fn, pred_fn, while_arg_result_types, rewriter);

    // Use the pred chain as the chain to the while body. The rest of the args
    // should be the same as the pred_args.
    auto &while_args = pred_args;
    while_args[0] = pred_chain;

    auto new_op = rewriter.create<tfrt::compiler::WhileOp>(
        op.getLoc(), while_arg_result_types, first_iteration_bool_cond,
        while_args, new_body_fn.sym_name());

    rewriter.replaceOp(op, new_op.getResults().drop_front());

    if (!has_at_most_tensor_array_effect) out_chain = new_op.getResult(0);

    if (!mlir::MemoryEffectOpInterface::hasNoEffect(op)) {
      // Register the converted op so that it can be retrieved by successors.
      // TODO(chky): Add OpTraits or OpInterface, rather than assume first
      // result is a chain.
      corert_converter_.RegisterLocalSideEffectChain(op, out_chain);
    }
    return success();
  }

 private:
  bool HasAtMostTensorArrayEffect(mlir::TF::WhileOp op) const {
    return tensor_array_side_effect_analysis_.HasAtMostTensorArrayEffect(
               op.cond_function()) &&
           tensor_array_side_effect_analysis_.HasAtMostTensorArrayEffect(
               op.body_function());
  }

  mlir::FuncOp GetPredicateFunction(
      mlir::TF::WhileOp op, mlir::FlatSymbolRefAttr cond_fn,
      mlir::TypeRange arg_types,
      mlir::ConversionPatternRewriter &rewriter) const;

  mlir::FuncOp GetWhileBodyFunction(
      mlir::TF::WhileOp op, mlir::FlatSymbolRefAttr body_fn,
      mlir::FuncOp pred_fn, mlir::TypeRange arg_types,
      mlir::ConversionPatternRewriter &rewriter) const;

  mlir::TypeConverter &type_converter_;
  CoreRTConverter &corert_converter_;
  mlir::SymbolTable &symbol_table_;
  const tfrt_compiler::TensorArraySideEffectAnalysis
      &tensor_array_side_effect_analysis_;
  bool func_use_fallback_tensor_;
};

// Create the pred function that contains a call to the original cond function
// and a predicate kernel that converts the cond tensor to a boolean value,
// e.g.
//
//   func @pred_fn(%ch, %arg) {
//     %ch0, %cond_tensor = tfrt.call @original_cond_fn(%ch, %arg)
//     %cond_bool = tfrt_fallback_async.predicate %cond_tensor
//     return %ch0, %cond_bool
//   }
//
1266 mlir::FuncOp TFRTWhileOpConversion::GetPredicateFunction(
1267 mlir::TF::WhileOp op, mlir::FlatSymbolRefAttr cond_fn,
1268 mlir::TypeRange arg_types,
1269 mlir::ConversionPatternRewriter &rewriter) const {
1270 std::string pred_fn_name = cond_fn.getValue().str() + "/tfrt_predicate";
1271
1272 if (auto pred_fn = symbol_table_.lookup<mlir::FuncOp>(pred_fn_name)) {
1273 return pred_fn;
1274 }
1275
1276 auto func_op = op->getParentOfType<mlir::FuncOp>();
1277
1278 mlir::ConversionPatternRewriter::InsertionGuard insertion_guard(rewriter);
1279 rewriter.setInsertionPointAfter(func_op);
1280
1281 std::array<mlir::Type, 2> pred_result_types = {
1282 rewriter.getType<tfrt::compiler::ChainType>(), rewriter.getI1Type()};
1283
1284 auto func_type = rewriter.getFunctionType(arg_types, pred_result_types);
1285
1286 auto pred_fn =
1287 rewriter.create<mlir::FuncOp>(op.getLoc(), pred_fn_name, func_type);
1288
1289 auto *block = pred_fn.addEntryBlock();
1290 rewriter.setInsertionPointToStart(block);
1291
1292 // There are at least two arguments, with the first being !tfrt.chain and the
1293 // second being either !tfrt_fallback.tf_tensor or !corert.tensorhandle.
1294 // cond_fn must have two results. The first of them must also be !tfrt.chain
1295 // and the second must have the same tensor type as the arguments. So we can
1296 // just use the first two arg types as the result types.
1297 assert(arg_types.size() >= 2);
1298 auto cond_result_types = arg_types.take_front(2);
1299
1300 auto call_cond_fn = rewriter.create<tfrt::compiler::CallOp>(
1301 op.getLoc(), cond_result_types, cond_fn, block->getArguments());
1302
1303 auto chain = call_cond_fn.getResult(0);
1304 auto cond = call_cond_fn.getResult(1);
1305
1306 auto bool_cond = GetPredicate(op, cond, rewriter);
1307 if (!bool_cond) return {};
1308
1309 llvm::SmallVector<mlir::Value, 2> results = {chain, bool_cond};
1310
1311 rewriter.create<tfrt::compiler::ReturnOp>(op.getLoc(), results);
1312
1313 symbol_table_.insert(pred_fn);
1314
1315 return pred_fn;
1316 }
1317
1318 // Create the new while body function that contains a call to the original
1319 // while body and then a call to the pred function, e.g.
1320 //
1321 // func @new_while_body(%ch, %arg) {
1322 // %ch0, %result = tfrt.call @original_body(%ch, %arg)
1323 // %ch1, %cond_bool = tfrt.call @pred_function(%ch0, %arg)
1324 // tfrt.return %ch1, %result, %cond_bool
1325 // }
1326 //
1327 mlir::FuncOp TFRTWhileOpConversion::GetWhileBodyFunction(
1328 mlir::TF::WhileOp op, mlir::FlatSymbolRefAttr original_body_fn,
1329 mlir::FuncOp pred_fn, mlir::TypeRange arg_types,
1330 mlir::ConversionPatternRewriter &rewriter) const {
1331 std::string body_fn_name = original_body_fn.getValue().str() + "/tfrt_body";
1332
1333 if (auto body_fn = symbol_table_.lookup<mlir::FuncOp>(body_fn_name)) {
1334 return body_fn;
1335 }
1336
1337 auto func_op = op->getParentOfType<mlir::FuncOp>();
1338
1339 mlir::ConversionPatternRewriter::InsertionGuard insertion_guard(rewriter);
1340 rewriter.setInsertionPointAfter(func_op);
1341
1342 auto while_result_types = arg_types;
1343
1344 llvm::SmallVector<mlir::Type, 4> body_result_types(arg_types.begin(),
1345 arg_types.end());
1346
1347 // The last result of the while body function is the boolean condition.
1348 body_result_types.push_back(rewriter.getI1Type());
1349 auto func_type = rewriter.getFunctionType(arg_types, body_result_types);
1350 auto body_fn =
1351 rewriter.create<mlir::FuncOp>(op.getLoc(), body_fn_name, func_type);
1352
1353 auto *block = body_fn.addEntryBlock();
1354 rewriter.setInsertionPointToStart(block);
1355
1356 // Insert a call to the original body function.
1357 auto call_original_body_fn = rewriter.create<tfrt::compiler::CallOp>(
1358 op.getLoc(), while_result_types, original_body_fn, block->getArguments());
1359
1360 // Insert a call to the pred function, which contains a call to the original
1361 // cond function and the predicate kernel that converts the tensor to a
1362 // boolean value.
1363 auto call_pred_fn = rewriter.create<tfrt::compiler::CallOp>(
1364 op.getLoc(), pred_fn.getType().getResults(), pred_fn.sym_name(),
1365 call_original_body_fn.getResults());
1366
1367 auto pred_chain = call_pred_fn.getResult(0);
1368
1369 llvm::SmallVector<mlir::Value, 4> body_results;
1370 // The first result should be the chain returned from the pred function.
1371 body_results.push_back(pred_chain);
1372 // Then come the results from the original while body, excluding the first
1373 // chain (which is replaced by `pred_chain`).
1374 for (auto res : llvm::drop_begin(call_original_body_fn.getResults())) {
1375 body_results.push_back(res);
1376 }
1377 // The last result should be the boolean value converted from the condition.
1378 auto bool_cond = call_pred_fn.getResult(1);
1379 body_results.push_back(bool_cond);
1380
1381 rewriter.create<tfrt::compiler::ReturnOp>(op.getLoc(), body_results);
1382
1383 symbol_table_.insert(body_fn);
1384
1385 return body_fn;
1386 }
1387
1388 // TODO(ezhulenev): tf_device.cluster operations after auto-fusion should
1389 // have the correct device assigned based on the fused operations. We should
1390 // use this device to convert operands and results from/to corert handles.
1391 // For now it is safe to assume that it is "CPU" because we do not support
1392 // any other devices and do not support distributed models.
1393 constexpr char kCpuRtDevice[] = "CPU";
1394
1395 // Convert cpurt.call operations to the tf_cpurt.fallback.execute operation.
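// Schematically (a sketch; the operand conversion to fallback tensors is
// elided and the exact assembly format may differ):
//
//   %0 = cpurt.call @kernel::@compute(%arg) : (...) -> (...)
//
// is replaced with
//
//   %0 = tf_cpurt.fallback.execute @kernel::@compute (%fallback_arg)
//          device("CPU") : (!tfrt_fallback.tf_tensor) -> !tfrt_fallback.tf_tensor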
1396 class CpuRtCallToCpurtCompileAndExecuteConversion
1397 : public OpConversionPattern<tfrt::cpu::jit::CallOp> {
1398 public:
1399 explicit CpuRtCallToCpurtCompileAndExecuteConversion(MLIRContext *context)
1400 : OpConversionPattern<tfrt::cpu::jit::CallOp>(context) {}
1401
1402 LogicalResult matchAndRewrite(
1403 tfrt::cpu::jit::CallOp call, ArrayRef<Value> operands,
1404 ConversionPatternRewriter &rewriter) const override {
1405 // Convert operands to fallback tensors.
1406 llvm::SmallVector<Value, 4> fallback_operands;
1407 if (failed(tfrt_compiler::ConvertFallbackOperands(
1408 call, kCpuRtDevice, operands, &fallback_operands, rewriter)))
1409 return rewriter.notifyMatchFailure(call, "failed to convert operand");
1410
1411 // tf_cpurt.fallback.execute always produces fallback tensors.
1412 llvm::SmallVector<Type, 4> result_types(
1413 call->getNumResults(),
1414 rewriter.getType<tfrt::fallback::TFTensorType>());
1415
1416 // Replace cpurt.call operation with a tf_cpurt.fallback.execute operation.
1417 rewriter.replaceOpWithNewOp<tf_cpurt::FallbackExecuteOp>(
1418 call, result_types, call.callee(), fallback_operands, kCpuRtDevice);
1419
1420 return success();
1421 }
1422 };
1423
1424 // Helper function for specifying legal dialects for conversion to CoreRT.
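// A FuncOp is only considered legal after conversion if its signature threads
// a chain through the first argument and the first result, e.g. (a sketch,
// shown with fallback tensor types):
//
//   func @fn(%ch: !tfrt.chain, %t: !tfrt_fallback.tf_tensor)
//       -> (!tfrt.chain, !tfrt_fallback.tf_tensor)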
1425 void SetUpTFToTFRTConversionLegality(mlir::ConversionTarget *target,
1426 mlir::TypeConverter *func_type_converter,
1427 mlir::Type chain_type) {
1428 target->addLegalDialect<tfrt::corert::CoreRTDialect>();
1429 target->addLegalDialect<tfrt::fallback_async::FallbackAsyncDialect>();
1430 target->addLegalDialect<tfrt::compiler::TFRTDialect>();
1431 target->addLegalDialect<tfrt::dist::DistributedDialect>();
1432 target->addLegalDialect<tfrt::test::TestDialect>();
1433 target->addLegalDialect<tf_cpurt::CpuRuntimeDialect>();
1434 target->addIllegalDialect<TF::TensorFlowDialect>();
1435 target->addIllegalDialect<tf_device::TensorFlowDeviceDialect>();
1436 target->addIllegalDialect<tfrt::cpu::jit::CpuRuntimeDialect>();
1437 target->addDynamicallyLegalOp<mlir::FuncOp>([func_type_converter,
1438 chain_type](FuncOp op) {
1439 auto func_type = op.getType();
1440 if (func_type.getNumInputs() == 0 || func_type.getInput(0) != chain_type)
1441 return false;
1442 if (func_type.getNumResults() == 0 || func_type.getResult(0) != chain_type)
1443 return false;
1444
1445 return func_type_converter->isSignatureLegal(op.getType()) &&
1446 func_type_converter->isLegal(&op.getBody());
1447 });
1448 }
1449
1450 // Helper function for inserting TFRT CpuRT dialect conversions.
1451 void PopulateCpuRtConversionPatterns(MLIRContext *context,
1452 OwningRewritePatternList *patterns,
1453 CoreRTConverter *corert_converter) {
1454 // Lower cpurt.call to the pair of compile and execute operations.
1455 patterns->insert<CpuRtCallToCpurtCompileAndExecuteConversion>(context);
1456 }
1457
1458 // Helper function for inserting TF dialect to TFRT dialect op conversion
1459 // patterns.
1460 void PopulateTFToTFRTConversionPatterns(
1461 mlir::MLIRContext *context, mlir::OwningRewritePatternList *patterns,
1462 CoreRTConverter *corert_converter,
1463 tfrt_compiler::FallbackConverter *fallback_converter,
1464 mlir::SymbolTable *symbol_table,
1465 const tfrt_compiler::CostAnalysis *cost_analysis,
1466 const tfrt_compiler::TensorArraySideEffectAnalysis
1467 *tensor_array_side_effect_analysis,
1468 bool enable_native_ops, bool func_use_fallback_tensor,
1469 bool tpu_lower_to_fallback) {
1470 // By default, we lower all TF ops to fallback ops.
1471 patterns->insert<FallbackExecuteOpConversion>(
1472 context, corert_converter, fallback_converter, cost_analysis,
1473 tpu_lower_to_fallback);
1474 patterns->insert<FallbackConstOpConversion, FallbackSetResourceOp,
1475 FallbackGetResourceOp>(context, corert_converter);
1476
1477 // For control flow ops, we handle them according to the option.
1478 mlir::TypeConverter *func_type_converter;
1479 if (func_use_fallback_tensor) {
1480 func_type_converter = fallback_converter;
1481 } else {
1482 func_type_converter = corert_converter;
1483 }
1484 patterns->insert<TFRTFuncOpSignatureConversion>(context, func_type_converter);
1485 patterns->insert<TFRTReturnOpConversion>(context, corert_converter,
1486 func_use_fallback_tensor);
1487 patterns->insert<TFRTWhileOpConversion>(
1488 context, func_type_converter, corert_converter, symbol_table,
1489 tensor_array_side_effect_analysis, func_use_fallback_tensor);
1490 patterns->insert<TFRTCallOpConversion<mlir::TF::StatefulPartitionedCallOp>,
1491 TFRTCallOpConversion<mlir::TF::PartitionedCallOp>,
1492 TFRTCaseOpConversion, TFRTCondOpConversion>(
1493 context, func_type_converter, corert_converter, func_use_fallback_tensor);
1494
1495 // For tf.BatchFunction, we need a special fallback op to batch a BEF
1496 // function.
1497 patterns->insert<FallbackBatchFunctionOpConversion>(context,
1498 corert_converter);
1499
1500 // The patterns below are preferred over fallback lowering as we want to use
1501 // the CoreRT interface for native kernels. This is only temporary and it will
1502 // be refactored to use the SSOT interface.
1503
1504 // Here we use specialized patterns for tf.Const on CPU as it is incorrect to
1505 // use the ExecuteOp pattern to convert a string tensor attribute.
1506 patterns->insert<CoreRTConstStringTensorOpConversion,
1507 CoreRTConstDenseTensorOpConversion>(context,
1508 corert_converter);
1509
1510 // For tf.Const op on GPU, we still lower it to corert.executeop temporarily.
1511 //
1512 // TODO(chky): Use specialized patterns for tf.Const ops on GPU.
1513 patterns->insert<CoreRTExecuteOpConversion<TF::ConstOp>>(
1514 context, corert_converter, kGpuDeviceName);
1515
1516 if (enable_native_ops) {
1517 // The TF operations below will be converted to use corert.executeop, which
1518 // will invoke the corresponding TFRT native op if implemented.
1519 // TODO(b/187942369): Pattern registration for TF operations is not
1520 // sustainable currently. We need to figure out a plan.
1521 patterns->insert<CoreRTExecuteOpConversion<TF::AddV2Op>,
1522 // TODO(chky): Move the ReadVariableOp + Identity pattern
1523 // to optimizer.
1524 // CoreRTExecuteOpConversion<TF::IdentityOp>,
1525 CoreRTExecuteOpConversion<TF::MulOp>,
1526 CoreRTExecuteOpConversion<TF::BiasAddOp>,
1527 CoreRTExecuteOpConversion<TF::Conv2DOp>,
1528 CoreRTExecuteOpConversion<TF::ConcatV2Op>,
1529 CoreRTExecuteOpConversion<TF::CastOp>,
1530 CoreRTExecuteOpConversion<TF::ExpandDimsOp>,
1531 CoreRTExecuteOpConversion<TF::TransposeOp>,
1532 CoreRTExecuteOpConversion<TF::FusedBatchNormV3Op>,
1533 CoreRTExecuteOpConversion<TF::_FusedBatchNormExOp>,
1534 CoreRTExecuteOpConversion<TF::LogOp>,
1535 CoreRTExecuteOpConversion<TF::Log1pOp>,
1536 CoreRTExecuteOpConversion<TF::LogSoftmaxOp>,
1537 CoreRTExecuteOpConversion<TF::MatMulOp>,
1538 CoreRTExecuteOpConversion<TF::_FusedMatMulOp>,
1539 CoreRTExecuteOpConversion<TF::MaxPoolOp>,
1540 CoreRTExecuteOpConversion<TF::MeanOp>,
1541 CoreRTExecuteOpConversion<TF::MulOp>,
1542 CoreRTExecuteOpConversion<TF::PadOp>,
1543 CoreRTExecuteOpConversion<TF::RealDivOp>,
1544 CoreRTExecuteOpConversion<TF::ReluOp>,
1545 CoreRTExecuteOpConversion<TF::ReshapeOp>,
1546 CoreRTExecuteOpConversion<TF::RsqrtOp>,
1547 CoreRTExecuteOpConversion<TF::SoftmaxOp>,
1548 CoreRTExecuteOpConversion<TF::ShapeOp>,
1549 CoreRTExecuteOpConversion<TF::SigmoidOp>,
1550 CoreRTExecuteOpConversion<TF::SubOp>,
1551 CoreRTExecuteOpConversion<TF::TileOp>,
1552 CoreRTExecuteOpConversion<TF::TanhOp>,
1553 CoreRTExecuteOpConversion<TF::ZerosLikeOp>>(
1554 context, corert_converter);
1555 }
1556 }
1557
1558 // Lower TF dialect MLIR to TFRT dialect.
1559 class TfToTfrtConversionPass
1560 : public mlir::PassWrapper<TfToTfrtConversionPass,
1561 mlir::OperationPass<mlir::ModuleOp>> {
1562 void getDependentDialects(mlir::DialectRegistry &registry) const override {
1563 getDependentConversionDialects(registry);
1564
1565 if (target_tpu_) RegisterTPUDialects(&registry);
1566 }
1567
1568 llvm::StringRef getArgument() const final { return "tf-to-tfrt"; }
1569 llvm::StringRef getDescription() const final {
1570 return "Convert Tensorflow dialect (generated from tf.function) to TFRT "
1571 "dialect.";
1572 }
1573
1574 public:
1575 TfToTfrtConversionPass() = default;
1576 explicit TfToTfrtConversionPass(const TfrtPipelineOptions &options) {
1577 target_tpu_ = options.target_tpu;
1578 enable_native_ops_ = options.enable_native_ops;
1579 tpu_use_core_selector_ = options.tpu_use_core_selector;
1580 tpu_use_bundled_transfer_ = options.tpu_use_bundled_transfer;
1581 tpu_lower_to_fallback_ = options.tpu_lower_to_fallback;
1582 tpu_transfer_result_to_host_ = options.tpu_transfer_result_to_host;
1583 cost_threshold_ = options.cost_threshold;
1584 upper_cost_threshold_ = options.upper_cost_threshold;
1585 merge_inter_dependent_streams_ = options.merge_inter_dependent_streams;
1586 func_use_fallback_tensor_ = options.func_use_fallback_tensor;
1587 }
1588 TfToTfrtConversionPass(const TfToTfrtConversionPass &) {}
1589
1590 mlir::LogicalResult runOnFunction(
1591 mlir::FuncOp func,
1592 const mlir::TF::SideEffectAnalysis::Info &side_effect_analysis,
1593 const tfrt_compiler::TensorArraySideEffectAnalysis
1594 &tensor_array_side_effect_analysis,
1595 tfrt_compiler::FallbackConverter &fallback_converter,
1596 mlir::SymbolTable &symbol_table) {
1597 auto &context = getContext();
1598 mlir::ConversionTarget target(context);
1599 mlir::OwningRewritePatternList patterns(&getContext());
1600 CoreRTConverter corert_converter(&context, &side_effect_analysis);
1601 tfrt_compiler::CostAnalysis cost_analysis(func);
1602
1603 if (target_tpu_)
1604 AddTPUTargetDialectAndPatterns(
1605 &target, &patterns, &context, &corert_converter,
1606 TfrtTpuExecuteOpConversionOptions{tpu_use_core_selector_,
1607 tpu_use_bundled_transfer_,
1608 tpu_transfer_result_to_host_},
1609 tpu_lower_to_fallback_);
1610
1611 mlir::TypeConverter *func_type_converter;
1612 if (func_use_fallback_tensor_) {
1613 func_type_converter = &fallback_converter;
1614 } else {
1615 func_type_converter = &corert_converter;
1616 }
1617 SetUpTFToTFRTConversionLegality(&target, func_type_converter,
1618 corert_converter.chain_type());
1619 PopulateCpuRtConversionPatterns(&context, &patterns, &corert_converter);
1620
1621 PopulateTFToTFRTConversionPatterns(
1622 &context, &patterns, &corert_converter, &fallback_converter,
1623 &symbol_table, &cost_analysis, &tensor_array_side_effect_analysis,
1624 enable_native_ops_, func_use_fallback_tensor_, tpu_lower_to_fallback_);
1625
1626 return mlir::applyPartialConversion(func, target, std::move(patterns));
1627 }
1628
1629 void runOnOperation() override {
1630 auto module = getOperation();
1631 const auto &side_effect_analysis =
1632 getAnalysis<mlir::TF::SideEffectAnalysis>();
1633
1634 tfrt_compiler::TensorArraySideEffectAnalysis
1635 tensor_array_side_effect_analysis(module);
1636 tfrt_compiler::FallbackConverter fallback_converter(&getContext());
1637
1638 mlir::SymbolTable symbol_table(module);
1639
1640 auto func_op_range = module.getOps<mlir::FuncOp>();
1641 llvm::SmallVector<mlir::FuncOp, 4> func_ops(func_op_range.begin(),
1642 func_op_range.end());
1643 for (auto func : func_ops) {
1644 if (!func.isExternal()) {
1645 if (mlir::failed(runOnFunction(
1646 func, side_effect_analysis.GetAnalysisForFunc(func),
1647 tensor_array_side_effect_analysis, fallback_converter,
1648 symbol_table))) {
1649 signalPassFailure();
1650 return;
1651 }
1652
1653 ChainDanglingValuesinFunction(func);
1654 }
1655 }
1656
1657 CreateFallbackInitializationFunction(module, fallback_converter);
1658
1659 // Set the cost threshold as a module attribute. It will be used later in the
1660 // runtime to decide whether a certain execution is cheap enough to be
1661 // executed inline.
1662 mlir::Builder builder(module);
1663 module->setAttr("tfrt.cost_threshold",
1664 builder.getI64IntegerAttr(cost_threshold_));
1665 module->setAttr("tfrt.upper_cost_threshold",
1666 builder.getI64IntegerAttr(upper_cost_threshold_));
1667 module->setAttr("tfrt.merge_inter_dependent_streams",
1668 builder.getBoolAttr(merge_inter_dependent_streams_));
1669 }
1670
1671 private:
1672 // Chain all dangling values (i.e. values with no users) together and merge
1673 // them with the first returned chain. This merged chain can be used to signal
1674 // the completion of all execution, including side effects.
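// For example, if %unused is a dangling value, the terminator
//
//   tfrt.return %ch, %res
//
// is rewritten roughly as follows (a sketch; the exact op syntax may differ):
//
//   %merged = tfrt.merge.chains %unused, %ch
//   tfrt.return %merged, %res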
1675 void ChainDanglingValuesinFunction(mlir::FuncOp func_op) {
1676 auto &block = func_op.front();
1677
1678 llvm::SmallVector<mlir::Value, 2> dangling_values;
1679
1680 // Find dangling function arguments.
1681 for (auto arg : block.getArguments()) {
1682 if (arg.use_empty()) {
1683 dangling_values.push_back(arg);
1684 }
1685 }
1686
1687 // Find dangling values produced by ops in the function.
1688 for (auto &op : block) {
1689 for (mlir::Value result : op.getResults()) {
1690 if (result.use_empty()) {
1691 dangling_values.push_back(result);
1692 }
1693 }
1694 }
1695
1696 if (dangling_values.empty()) return;
1697
1698 // Merge these dangling values with the first returned chain.
1699 auto return_op =
1700 llvm::cast<tfrt::compiler::ReturnOp>(block.getTerminator());
1701 auto chain = return_op->getOperand(0);
1702 assert(chain.getType().isa<tfrt::compiler::ChainType>());
1703 dangling_values.push_back(chain);
1704
1705 mlir::OpBuilder builder(return_op);
1706
1707 auto new_chain = builder.create<tfrt::compiler::MergeChainsOp>(
1708 return_op->getLoc(), builder.getType<tfrt::compiler::ChainType>(),
1709 dangling_values);
1710
1711 return_op->setOperand(0, new_chain);
1712 }
1713
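// Creates the "_tfrt_fallback_init" function that takes and returns a
// !tfrt.chain. It instantiates every fallback kernel recorded by the fallback
// converter and pre-compiles all JIT-compiled (tf_cpurt) kernels found in the
// module, merging the resulting chains before returning. Schematically (a
// sketch; the assembly syntax is illustrative only):
//
//   func @_tfrt_fallback_init(%in: !tfrt.chain) -> !tfrt.chain {
//     %ch0 = <create fallback kernel #0>(%in)
//     %ch1 = <create fallback kernel #1>(%ch0)
//     %c0  = <compile JIT kernel #0>
//     %out = <merge chains>(%ch1, %c0)
//     tfrt.return %out : !tfrt.chain
//   }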
1714 void CreateFallbackInitializationFunction(
1715 mlir::ModuleOp module,
1716 tfrt_compiler::FallbackConverter &fallback_converter) {
1717 mlir::OpBuilder builder(&module.body());
1718
1719 auto chain_type = builder.getType<tfrt::compiler::ChainType>();
1720
1721 auto func_op = builder.create<mlir::FuncOp>(
1722 module.getLoc(), "_tfrt_fallback_init",
1723 mlir::FunctionType::get(module.getContext(), /*inputs=*/chain_type,
1724 /*results=*/chain_type));
1725
1726 auto *block = func_op.addEntryBlock();
1727 builder.setInsertionPointToStart(block);
1728
1729 mlir::Value chain_value = block->getArgument(0);
1730
1731 // Create operations for all fallback kernels in the module.
1732 for (auto *op : fallback_converter.GetFallbackOps()) {
1733 auto create_op = builder.create<tfrt::fallback_async::CreateOp>(
1734 func_op.getLoc(), chain_type, chain_value);
1735
1736 create_op->setAttrs(op->getAttrs());
1737 create_op->setAttr("num_args", builder.getI64IntegerAttr(GetNumArgs(op)));
1738
1739 chain_value = create_op;
1740 }
1741
1742 // Pre-compile all JIT compiled kernels found in the module.
1743 llvm::SmallVector<Value> compiled;
1744
1745 // A set of SymbolRef attributes referencing compiled kernels.
1746 llvm::DenseSet<mlir::Attribute> kernels;
1747
1748 // Compile all kernels in parallel.
1749 module.walk([&](tf_cpurt::FallbackExecuteOp execute) {
1750 // Do not compile the same kernel multiple times.
1751 if (kernels.contains(execute.kernel())) return;
1752
1753 auto compile = builder.create<tf_cpurt::FallbackCompileOp>(
1754 execute.getLoc(), chain_type, execute.kernel(), execute.device());
1755 compiled.push_back(compile.getResult());
1756 kernels.insert(compile.kernel());
1757 });
1758
1759 // Wait for compilation to complete before returning from the init function.
1760 if (compiled.size() == 1) {
1761 chain_value = compiled[0];
1762 } else if (compiled.size() > 1) {
1763 chain_value = builder.create<tfrt::compiler::MergeChainsOp>(
1764 func_op.getLoc(), chain_type, compiled);
1765 }
1766
1767 builder.create<tfrt::compiler::ReturnOp>(func_op.getLoc(), chain_value);
1768 }
1769
1770 int64_t GetNumArgs(mlir::Operation *fallback_op) {
1771 if (auto execute_op =
1772 llvm::dyn_cast<tfrt::fallback_async::ExecuteOp>(fallback_op)) {
1773 return execute_op.operands().size();
1774 } else if (auto execute_op_seq =
1775 llvm::dyn_cast<tfrt::fallback_async::ExecuteOpSeq>(
1776 fallback_op)) {
1777 return execute_op_seq.operands().size();
1778 }
1779 llvm_unreachable("invalid fallback op type");
1780 }
1781
1782 Option<bool> target_tpu_{*this, "target-tpu",
1783 llvm::cl::desc("Target TPU programs if true."),
1784 llvm::cl::init(false)};
1785
1786 Option<bool> enable_native_ops_{
1787 *this, "enable-native-ops",
1788 llvm::cl::desc(
1789 "If true, native ops will be used on an opt-in basis "
1790 "instead of fallback ops. If false, no native ops are used."),
1791 llvm::cl::init(true)};
1792
1793 Option<bool> tpu_use_core_selector_{
1794 *this, "tpu-use-core-selector",
1795 llvm::cl::desc("If true, use ServingCoreSelector to pick TPU core. "
1796 "Otherwise, use the assigned core."),
1797 llvm::cl::init(true)};
1798
1799 Option<bool> tpu_use_bundled_transfer_{
1800 *this, "tpu-use-bundled-transfer",
1801 llvm::cl::desc("If true, use BundledTransferToTpuOp to transfer "
1802 "variables and input tensors to TPU."),
1803 llvm::cl::init(true)};
1804
1805 Option<bool> tpu_lower_to_fallback_{
1806 *this, "tpu-lower-to-fallback",
1807 llvm::cl::desc("If true, lower an TF op that's placed on TPU device "
1808 "to be executed by tfrt_fallback.execute."),
1809 llvm::cl::init(true)};
1810
1811 // TODO(b/194081364): remove this option once we unify servo TPU serving
1812 // result transfer behavior.
1813 Option<bool> tpu_transfer_result_to_host_{
1814 *this, "tpu-transfer-result-to-host",
1815 llvm::cl::desc("If true, transfer the result of tpurt.execute from TPU "
1816 "to host."),
1817 llvm::cl::init(true)};
1818
1819 Option<uint64_t> cost_threshold_{
1820 *this, "tfrt-cost-threshold",
1821 llvm::cl::desc(
1822 "The cost threshold to decide whether a sequence of operations is "
1823 "cheap, and then whether it can be executed inline."),
1824 llvm::cl::init(1)};
1825
1826 Option<int64_t> upper_cost_threshold_{
1827 *this, "tfrt-upper-cost-threshold",
1828 llvm::cl::desc(
1829 "The threshold to limit the merging of dependent sequence."),
1830 llvm::cl::init(-1)};
1831
1832 Option<bool> merge_inter_dependent_streams_{
1833 *this, "tfrt-merge-inter-dependent-streams",
1834 llvm::cl::desc("If true, streams with inter data depenedencies will be "
1835 "preferred to be merged for inline execution."),
1836 llvm::cl::init(false)};
1837
1838 Option<bool> func_use_fallback_tensor_{
1839 *this, "func-use-fallback-tensor",
1840 llvm::cl::desc(
1841 "If true, use TF tensor as input/output types in func (and other "
1842 "control flow) ops."),
1843 llvm::cl::init(false)};
1844 };
1845
1846 } // namespace
1847
1848 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
1849 CreateTfToTfrtConversionPass(const TfrtPipelineOptions &options) {
1850 return std::make_unique<TfToTfrtConversionPass>(options);
1851 }
1852
1853 // -------------------------------------------------------------------------- //
1854 // Outline tf_device.cluster operation regions into functions in the nested
1855 // modules and replaces all cluster operations with cpurt.call operations.
1856 // -------------------------------------------------------------------------- //
1857
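// For example, a cluster formed for TFRT JIT compilation (a sketch; the op
// payload and the exact assembly are illustrative):
//
//   %0 = "tf_device.cluster"() ({
//     %1 = "tf.Rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
//     tf_device.return %1 : tensor<?xf32>
//   }) {policy = "tfrt.auto-fusion"} : () -> tensor<?xf32>
//
// is outlined into a nested module marked with the `tfrt.compiled` attribute
// that contains an entrypoint function @compute, and the cluster itself is
// replaced with
//
//   %0 = cpurt.call @kernel::@compute(%arg0) : (tensor<?xf32>) -> tensor<?xf32>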
1858 class OutlineCpuRtClustersPass
1859 : public PassWrapper<OutlineCpuRtClustersPass, OperationPass<ModuleOp>> {
1860 public:
1861 llvm::StringRef getArgument() const final {
1862 return "tf-outline-cpurt-cluster";
1863 }
1864 llvm::StringRef getDescription() const final {
1865 return "Outlines `tf_device.cluster` operations into functions and "
1866 "replaces them with `cpurt.call` operations.";
1867 }
1868
1869 void runOnOperation() override;
1870
1871 void getDependentDialects(mlir::DialectRegistry &registry) const override {
1872 registry.insert<tfrt::cpu::jit::CpuRuntimeDialect>();
1873 }
1874
1875 private:
1876 struct CompiledModule {
1877 ModuleOp module;
1878 FuncOp entrypoint;
1879 llvm::SetVector<Value> operands;
1880 };
1881
1882 // Creates a nested module with a single function that will be compiled into
1883 // the kernel at runtime.
1884 CompiledModule CreateCompiledModule(tf_device::ClusterOp cluster,
1885 SymbolTable *symbol_table);
1886
1887 // Update compiled module entrypoint signature with inferred operand
1888 // constraints.
1889 LogicalResult SetEntrypointConstraints(CompiledModule &compiled);
1890
1891 // Outlines cluster operation regions into compiled modules, and replaces the
1892 // cluster operation with a cpurt.call operation.
1893 LogicalResult OutlineClusterOp(tf_device::ClusterOp cluster,
1894 SymbolTable *symbol_table);
1895
1896 // Mapping from the outlined module string representation to the module itself
1897 // and an entrypoint function. Used to deduplicate identical modules during
1898 // the `tf_device.cluster` outlining.
1899 llvm::StringMap<std::pair<ModuleOp, FuncOp>> outlined_;
1900 };
1901
1902 OutlineCpuRtClustersPass::CompiledModule
1903 OutlineCpuRtClustersPass::CreateCompiledModule(tf_device::ClusterOp cluster,
1904 SymbolTable *symbol_table) {
1905 MLIRContext *ctx = cluster->getContext();
1906 Location loc = cluster.getLoc();
1907
1908 // Create a module that will hold compiled function and async wrappers.
1909 // TODO(ezhulenev): Give better names to module and function.
1910 auto compiled_module = ModuleOp::create(loc, {"kernel"});
1911 compiled_module->setAttr("tfrt.compiled", UnitAttr::get(ctx));
1912
1913 SymbolTable compiled_module_symbol_table(compiled_module);
1914
1915 // Find out the cluster arguments and their types.
1916 llvm::SetVector<Value> live_ins;
1917 getUsedValuesDefinedAbove(cluster.body(), cluster.body(), live_ins);
1918
1919 llvm::SmallVector<Type, 4> operand_types;
1920 operand_types.reserve(live_ins.size());
1921 for (Value v : live_ins) operand_types.emplace_back(v.getType());
1922
1923 // Create a function in the compiled module.
1924 auto compiled_func_type =
1925 FunctionType::get(ctx, operand_types, cluster->getResultTypes());
1926 auto compiled_func = FuncOp::create(loc, "compute", compiled_func_type);
1927 compiled_module_symbol_table.insert(compiled_func);
1928
1929 // Replace uses of live-in values within cluster region with block arguments.
1930 Block *compiled_func_block = compiled_func.addEntryBlock();
1931 for (auto p : llvm::zip(live_ins, compiled_func_block->getArguments()))
1932 replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), cluster.body());
1933
1934 // Move all operations in cluster into compiled_func's entry block.
1935 auto &cluster_body = cluster.GetBody().getOperations();
1936 compiled_func_block->getOperations().splice(
1937 compiled_func_block->end(), cluster_body, cluster_body.begin(),
1938 cluster_body.end());
1939
1940 // Replace `tf_device.return` terminator with `return` in the function body.
1941 auto device_return =
1942 cast<tf_device::ReturnOp>(compiled_func_block->getTerminator());
1943 OpBuilder builder(device_return.getOperation());
1944 builder.create<ReturnOp>(device_return.getLoc(), device_return.getOperands());
1945 device_return.erase();
1946
1947 // TODO(ezhulenev): MLIR doesn't define operation equivalence upstream yet;
1948 // replace module printing with a more principled solution when available.
1949 // Operations in the cluster can be in a different order and yet define
1950 // identical Tensorflow programs; with the current approach we will not be
1951 // able to detect duplicates like this.
1952
1953 // Serialize prepared module to string.
1954 std::string serialized;
1955 llvm::raw_string_ostream os(serialized);
1956 compiled_module.print(os);
1957
1958 // Try to find if identical module was already outlined.
1959 auto it = outlined_.find(serialized);
1960
1961 // Return identical module that was already outlined earlier.
1962 if (it != outlined_.end()) {
1963 compiled_module.erase(); // erase identical module
1964 return {it->second.first, it->second.second, live_ins};
1965 }
1966
1967 // Insert compiled module into the symbol table and assign it a unique name.
1968 symbol_table->insert(compiled_module);
1969
1970 // Cache unique module.
1971 outlined_.insert({std::move(serialized), {compiled_module, compiled_func}});
1972
1973 return {compiled_module, compiled_func, live_ins};
1974 }
1975
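// For example, after constraint propagation the entrypoint arguments may be
// annotated as follows (a sketch; the constraint names are illustrative):
//
//   func @compute(%arg0: tensor<?xf32> {cpurt.constraint = "rank"},
//                 %arg1: tensor<i32> {cpurt.constraint = "value"}) -> ...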
1976 LogicalResult OutlineCpuRtClustersPass::SetEntrypointConstraints(
1977 CompiledModule &compiled) {
1978 FuncOp func = compiled.entrypoint;
1979
1980 // Functions outlined from cpurt device clusters must have a single block.
1981 assert(func.body().getBlocks().size() == 1 && "expected single block");
1982
1983 mlir::TFDevice::ClusteringPolicySet policies;
1984 populateTfCpurtConstraintsPolicies(policies);
1985
1986 // Infer constraints on the values defined in the entrypoint function
1987 // (including function entry block arguments).
1988 mlir::TFDevice::ValuesConstraintSet constraints;
1989 if (failed(mlir::TFDevice::PropagateValuesConstraints(
1990 func.body(), policies, constraints, /*resolve=*/true)))
1991 return failure();
1992
1993 // Annotate arguments with inferred constraints.
1994 for (unsigned i = 0; i < func.getNumArguments(); ++i) {
1995 if (auto constraint = constraints.GetConstraint(func.getArgument(i))) {
1996 auto constraint_name = mlir::StringAttr::get(
1997 &getContext(), llvm::formatv("{0}", *constraint).str());
1998 func.setArgAttr(i, "cpurt.constraint", constraint_name);
1999 }
2000 }
2001
2002 return success();
2003 }
2004
2005 LogicalResult OutlineCpuRtClustersPass::OutlineClusterOp(
2006 tf_device::ClusterOp cluster, SymbolTable *symbol_table) {
2007 Location loc = cluster->getLoc();
2008 OpBuilder builder(cluster);
2009
2010 CompiledModule compiled_module = CreateCompiledModule(cluster, symbol_table);
2011 FuncOp compiled_func = compiled_module.entrypoint;
2012
2013 // Add constraints to the entrypoint arguments.
2014 if (failed(SetEntrypointConstraints(compiled_module))) return failure();
2015
2016 // Replace device cluster with a cpurt.call operation.
2017 auto module_name = *compiled_module.module.sym_name();
2018 auto func_name = compiled_func.sym_name();
2019 auto func_flat_ref = builder.getSymbolRefAttr(func_name);
2020 auto func_ref = builder.getSymbolRefAttr(module_name, {func_flat_ref});
2021
2022 auto cluster_func_op = builder.create<tfrt::cpu::jit::CallOp>(
2023 loc, cluster.getResultTypes(), func_ref,
2024 compiled_module.operands.getArrayRef());
2025
2026 cluster.replaceAllUsesWith(cluster_func_op);
2027 cluster.erase();
2028
2029 return success();
2030 }
2031
2032 void OutlineCpuRtClustersPass::runOnOperation() {
2033 ModuleOp module = getOperation();
2034 SymbolTable symbol_table(module);
2035
2036 OpBuilder builder(module.getContext());
2037 auto result = module.walk([&](tf_device::ClusterOp cluster) -> WalkResult {
2038 // Ensure that cluster was formed for TFRT JIT compilation.
2039 auto policy = cluster->getAttr("policy").dyn_cast_or_null<StringAttr>();
2040 if (!policy || policy.getValue() != "tfrt.auto-fusion")
2041 return WalkResult::advance();
2042
2043 if (failed(OutlineClusterOp(cluster, &symbol_table)))
2044 return WalkResult::interrupt();
2045 return WalkResult::advance();
2046 });
2047
2048 if (result.wasInterrupted()) {
2049 module->emitError("Failed to outline tf_device.cluster operations");
2050 signalPassFailure();
2051 }
2052 }
2053
2054 std::unique_ptr<Pass> CreateOutlineCpuRtClustersPass() {
2055 return std::make_unique<OutlineCpuRtClustersPass>();
2056 }
2057
2058 // -------------------------------------------------------------------------- //
2059
2060 void CreateTFExecutorToTFPipeline(mlir::OpPassManager &pm,
2061 const TfrtPipelineOptions &options) {
2062 // Due to b/191304670, functionalized while ops might not have the
2063 // shape_invariant attribute set correctly, which leads to failure in shape
2064 // inference. As a workaround, we conservatively (e.g., we place fewer
2065 // restrictions on tf.While, which avoids failures but leads to potentially
2066 // less exact shape inference) set the shape_invariant attribute in all
2067 // tf.While ops before performing shape inference.
2068 //
2069 // Note that this pass might not work well with TF XLA bridge, but this is
2070 // fine as TF XLA bridge is run before this pipeline. For CPU ops, less exact
2071 // shape inference may lead to fewer optimizations but it should be fine as it
2072 // is limited to while ops currently.
2073 //
2074 // TODO(b/191304670): Remove this pass once the shape_invariant attribute is
2075 // set correctly in the upstream.
2076 pm.addNestedPass<mlir::FuncOp>(
2077 tfrt_compiler::CreateSetShapeInvariantInWhileOps());
2078
2079 // We pass the MLIR module through the TF standard pipeline, which for
2080 // instance does shape inference, canonicalization, inlining, etc.
2081 mlir::TF::StandardPipelineOptions tf_options;
2082 tf_options.enable_inliner = true;
2083 mlir::TF::CreateTFStandardPipeline(pm, tf_options);
2084
2085 pm.addPass(mlir::TF::CreateConstantOpDeviceAssignmentPass());
2086
2087 // After the standard passes, we now have MLIR in the TF dialect, and we
2088 // convert reference variables to resource variables, which is best-effort.
2089 pm.addPass(CreateConvertReferenceVariableToResourceVariablePass());
2090
2091 // Move the tf.Assert op to the end of the function, so that it does not
2092 // impose unnecessary control dependencies on other ops.
2093 pm.addPass(tfrt_compiler::CreateReorderTfAssertPass());
2094
2095 // Optimize the side effects of control flow ops by examining the ops in
2096 // their callees.
2097 pm.addPass(tfrt_compiler::CreateOptimizeTfControlFlowSideEffectPass());
2098
2099 // Remove tf.If ops' operands that are produced by tf.Const ops.
2100 pm.addPass(tfrt_compiler::CreateRemoveTfIfConstArgsPass());
2101
2102 // Merge non-side-effecting tf.If ops if their operands are the same.
2103 pm.addPass(tfrt_compiler::CreateMergeTfIfOpsPass());
2104
2105 // Deduplicate functions invoked by tf.BatchFunction with the same
2106 // shared_name.
2107 pm.addPass(
2108 tfrt_compiler::CreateDeduplicateFunctionsInovkedByBatchFunctionPass());
2109
2110 // Apply standard optimization after optimizing control flow ops.
2111 pm.addPass(mlir::createInlinerPass());
2112 pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
2113
2114 // TODO(b/187876545): An extra shape inference pass is added because it does
2115 // not work well with the tf.Identity op that removes the ref type. So we work
2116 // around this by performing shape inference again after the reference variable
2117 // to resource variable conversion. We should remove this after b/187876545 is fixed.
2118 pm.addPass(mlir::TF::CreateTFShapeInferencePass());
2119
2120 pm.addNestedPass<mlir::FuncOp>(
2121 mlir::TFDevice::CreateLaunchToDeviceAttributePass());
2122
2123 // After all standard passes, run layout optimization to assign the optimal
2124 // data format for all layout-sensitive operations.
2125 mlir::TF::LayoutOptimizationPipelineOptions layout_optimization_options;
2126 layout_optimization_options.force_data_format =
2127 options.force_data_format.getValue();
2128 // TODO(b/191304261): Folding transpose in ops is buggy in the layout
2129 // optimization pass. Disable it to avoid errors in b/191304261. This should
2130 // not affect CPU performance as it does not change the number of ops, nor
2131 // does it change the types of the ops.
2132 layout_optimization_options.skip_fold_transpose_in_ops = true;
2133 mlir::TF::CreateLayoutOptimizationPipeline(pm.nest<mlir::FuncOp>(),
2134 layout_optimization_options);
2135
2136 // Run canonicalization pipeline to remove unused constants and bypassed
2137 // transpose operations left in the IR after layout optimization.
2138 pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
2139
2140 // Decompose resource ops as resource variables will be converted to tensors
2141 // directly. Only do this for non-TPU programs.
2142 if (options.decompose_resource_ops && !options.target_tpu)
2143 pm.addNestedPass<mlir::FuncOp>(
2144 mlir::TFDevice::CreateDecomposeResourceOpsPass());
2145
2146 // Then we assign default devices.
2147 pm.addNestedPass<mlir::FuncOp>(
2148 mlir::TF::CreateSimpleTFDeviceAssignmentPass(options.default_device));
2149
2150 pm.addNestedPass<mlir::FuncOp>(
2151 mlir::TF::CreateTensorDeviceCopyConversionPass());
2152
2153 // Outline auto-fusion clusters into tf_device.cluster operations and then
2154 // convert them to functions. We currently support only tfrt fallback tensors
2155 // as operands, so we disable these passes if we can have native ops after
2156 // lowering.
2157 if (!options.enable_native_ops) {
2158 pm.addNestedPass<mlir::FuncOp>(CreateTfCpurtClusteringPass(
2159 options.auto_fusion_oplist, options.auto_fusion_min_cluster_size));
2160
2161 // Sink small constants into the outlined clusters to reduce the number of
2162 // arguments for each of the execute operations.
2163 auto is_compilable_const = [](mlir::tf_device::ClusterOp cluster,
2164 mlir::ElementsAttr value) -> bool {
2165 // Ensure that cluster was formed for TFRT JIT compilation.
2166 auto policy = cluster->getAttr("policy").dyn_cast_or_null<StringAttr>();
2167 if (!policy || policy.getValue() != "tfrt.auto-fusion") return false;
2168
2169 // Check that TF->CPURT compiler supports constant compilation.
2170 return mlir::succeeded(IsCompilableConstant(value));
2171 };
2172
2173 pm.addNestedPass<mlir::FuncOp>(
2174 mlir::TFDevice::CreateClusterConstantSinkingPass(is_compilable_const));
2175
2176 // Outline formed JIT compiled device clusters into function.
2177 pm.addPass(CreateOutlineCpuRtClustersPass());
2178 }
2179
2180 // Rewrite operation sequences to device-specific fusions.
2181 DeviceNameUtils::ParsedName parsed_name;
2182
2183 // Ignore error.
2184 bool success =
2185 DeviceNameUtils::ParseFullName(options.default_device, &parsed_name);
2186 assert(success && "default device is invalid");
2187 (void)success;
2188
2189 if (parsed_name.has_type && parsed_name.type == DEVICE_GPU)
2190 pm.addNestedPass<mlir::FuncOp>(mlir::TF::CreateGpuOpFusionPass());
2191
2192 if (parsed_name.has_type && parsed_name.type == DEVICE_CPU)
2193 pm.addNestedPass<mlir::FuncOp>(mlir::TF::CreateFusedKernelMatcherPass());
2194
2195 pm.addPass(CreateLowerTFSavedModelPass(options.hoist_invariant_ops));
2196 }
2197
2198 void CreateTfExecutorToTfrtPipelineHelper(mlir::OpPassManager &pm,
2199 const TfrtPipelineOptions &options) {
2200 CreateTFExecutorToTFPipeline(pm, options);
2201
2202 pm.addPass(CreateTfToTfrtConversionPass(options));
2203
2204 pm.addPass(CreateRemoveDeviceAttributePass());
2205
2206 // Run optimizer on the MLIR module in CoreRT dialect.
2207 if (options.enable_optimizer) {
2208 pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
2209 pm.addPass(mlir::createInlinerPass());
2210 pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
2211 pm.addNestedPass<mlir::FuncOp>(
2212 tfrt_compiler::CreateInsertFallbackTensorCopyPass());
2213 }
2214 }
2215
2216 // If verbose logging is on, dump the output of each pass to a file directory,
2217 // set via env var TF_DUMP_GRAPH_PREFIX. e.g.:
2218 // export TF_DUMP_GRAPH_PREFIX=/tmp/mlir
2219 void CreateTfExecutorToTfrtPipeline(mlir::PassManager &pm,
2220 const TfrtPipelineOptions &options) {
2221 if (VLOG_IS_ON(1)) {
2222 // Print the whole module after each pass, which requires disabling
2223 // multi-threading as well.
2224 pm.getContext()->disableMultithreading();
2225 pm.enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>(
2226 /*print_module_scope=*/true));
2227 }
2228 CreateTfExecutorToTfrtPipelineHelper(pm, options);
2229 }
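// Typical usage of the pipeline above (a sketch; error handling elided):
//
//   mlir::PassManager pm(module.getContext());
//   TfrtPipelineOptions options;
//   CreateTfExecutorToTfrtPipeline(pm, options);
//   if (mlir::failed(pm.run(module))) {
//     // Handle conversion failure.
//   }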
2230
2231 static mlir::PassRegistration<TfToTfrtConversionPass> tf_to_tfrt_pass;
2232
2233 static mlir::PassRegistration<OutlineCpuRtClustersPass>
2234 tf_outline_cpurt_cluster_pass(CreateOutlineCpuRtClustersPass);
2235
2236 static mlir::PassPipelineRegistration<TfrtPipelineOptions> tf_pipeline(
2237 "tf-executor-to-tfrt-pipeline",
2238 "Convert Tensorflow Executor dialect to TFRT dialect and "
2239 "also apply necessary optimization passes.",
2240 CreateTfExecutorToTfrtPipelineHelper);
2241
2242 } // namespace tensorflow
2243